diff --git a/Cargo.lock b/Cargo.lock index 15cc8bc..0b943b1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20,6 +20,20 @@ dependencies = [ "winapi", ] +[[package]] +name = "assert_cmd" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54f002ce7d0c5e809ebb02be78fd503aeed4a511fd0fcaff6e6914cbdabbfa33" +dependencies = [ + "bstr", + "doc-comment", + "predicates", + "predicates-core", + "predicates-tree", + "wait-timeout", +] + [[package]] name = "atty" version = "0.2.14" @@ -52,6 +66,17 @@ dependencies = [ "generic-array", ] +[[package]] +name = "bstr" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90682c8d613ad3373e66de8c6411e0ae2ab2571e879d2efbf73558cc66f21279" +dependencies = [ + "lazy_static", + "memchr", + "regex-automata", +] + [[package]] name = "bumpalo" version = "3.7.0" @@ -168,6 +193,12 @@ dependencies = [ "syn", ] +[[package]] +name = "difflib" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6184e33543162437515c2e2b48714794e37845ec9851711914eec9d308f6ebe8" + [[package]] name = "digest" version = "0.9.0" @@ -181,6 +212,7 @@ dependencies = [ name = "distant" version = "0.13.0" dependencies = [ + "assert_cmd", "bytes", "derive_more", "flexi_logger", @@ -197,6 +229,7 @@ dependencies = [ "serde_json", "structopt", "strum", + "tempfile", "tokio", "tokio-stream", "tokio-util", @@ -204,6 +237,12 @@ dependencies = [ "whoami", ] +[[package]] +name = "doc-comment" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" + [[package]] name = "ecdsa" version = "0.12.3" @@ -216,6 +255,12 @@ dependencies = [ "signature", ] +[[package]] +name = "either" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457" + [[package]] name = "elliptic-curve" version = "0.10.5" @@ -448,6 +493,15 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "itertools" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69ddb889f9d0d08a67338271fa9b62996bc788c7796a5c18cf057420aaed5eaf" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "0.4.7" @@ -648,6 +702,33 @@ version = "0.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac74c624d6b2d21f425f752262f42188365d7b8ff1aff74c82e45136510a4857" +[[package]] +name = "predicates" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c143348f141cc87aab5b950021bac6145d0e5ae754b0591de23244cee42c9308" +dependencies = [ + "difflib", + "itertools", + "predicates-core", +] + +[[package]] +name = "predicates-core" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57e35a3326b75e49aa85f5dc6ec15b41108cf5aee58eabb1f274dd18b73c2451" + +[[package]] +name = "predicates-tree" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7dd0fd014130206c9352efbdc92be592751b2b9274dff685348341082c6ea3d" +dependencies = [ + "predicates-core", + "treeline", +] + [[package]] name = "proc-macro-error" version = "1.0.4" @@ -762,12 +843,27 @@ dependencies = [ "regex-syntax", ] +[[package]] +name = "regex-automata" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" + [[package]] name = "regex-syntax" version = "0.6.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b" +[[package]] +name = "remove_dir_all" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3acd125665422973a33ac9d3dd2df85edad0f4ae9b00dafb1a05e43a9f5ef8e7" +dependencies = [ + "winapi", +] + [[package]] name = "ryu" version = "1.0.5" @@ -951,6 +1047,20 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "tempfile" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dac1c663cfc93810f88aed9b8941d48cabf856a1b111c29a40439018d870eb22" +dependencies = [ + "cfg-if", + "libc", + "rand", + "redox_syscall", + "remove_dir_all", + "winapi", +] + [[package]] name = "textwrap" version = "0.11.0" @@ -1047,6 +1157,12 @@ dependencies = [ "tokio", ] +[[package]] +name = "treeline" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7f741b240f1a48843f9b8e0444fb55fb2a4ff67293b50a9179dfd5ea67f8d41" + [[package]] name = "typenum" version = "1.13.0" @@ -1083,6 +1199,15 @@ version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5fecdca9a5291cc2b8dcf7dc02453fee791a280f3743cb0905f8822ae463b3fe" +[[package]] +name = "wait-timeout" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f200f5b12eb75f8c1ed65abd4b2db8a6e1b138a20de009dacee265a2498f3f6" +dependencies = [ + "libc", +] + [[package]] name = "walkdir" version = "2.3.2" diff --git a/Cargo.toml b/Cargo.toml index 7bc2ae8..ca41e09 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,3 +39,7 @@ fork = "0.1.18" lazy_static = "1.4.0" structopt = "0.3.22" whoami = "1.1.2" + +[dev-dependencies] +assert_cmd = "2.0.0" +tempfile = "3.2.0" diff --git a/src/cli/exit.rs b/src/cli/exit.rs index c911bf4..5e98daa 100644 --- a/src/cli/exit.rs +++ b/src/cli/exit.rs @@ -1,38 +1,41 @@ -use crate::core::net::TransportError; +use crate::core::{client::RemoteProcessError, net::TransportError}; /// Exit codes following https://www.freebsd.org/cgi/man.cgi?query=sysexits&sektion=3 #[derive(Copy, Clone, PartialEq, Eq, Hash)] pub enum ExitCode { /// EX_USAGE (64) - being used when arguments missing or bad arguments provided to CLI - Usage = 64, + Usage, /// EX_DATAERR (65) - being used when bad data received not in UTF-8 format or transport data /// is bad - DataErr = 65, + DataErr, /// EX_NOINPUT (66) - being used when not getting expected data from launch - NoInput = 66, + NoInput, /// EX_NOHOST (68) - being used when failed to resolve a host - NoHost = 68, + NoHost, /// EX_UNAVAILABLE (69) - being used when IO error encountered where connection is problem - Unavailable = 69, + Unavailable, /// EX_SOFTWARE (70) - being used for internal errors that can occur like joining a task - Software = 70, + Software, /// EX_OSERR (71) - being used when fork failed - OsErr = 71, + OsErr, /// EX_IOERR (74) - being used as catchall for IO errors - IoError = 74, + IoError, /// EX_TEMPFAIL (75) - being used when we get a timeout - TempFail = 75, + TempFail, /// EX_PROTOCOL (76) - being used as catchall for transport errors - Protocol = 76, + Protocol, + + /// Custom exit code to pass back verbatim + Custom(i32), } /// Represents an error that can be converted into an exit code @@ -40,7 +43,19 @@ pub trait ExitCodeError: 
std::error::Error { fn to_exit_code(&self) -> ExitCode; fn to_i32(&self) -> i32 { - self.to_exit_code() as i32 + match self.to_exit_code() { + ExitCode::Usage => 64, + ExitCode::DataErr => 65, + ExitCode::NoInput => 66, + ExitCode::NoHost => 68, + ExitCode::Unavailable => 69, + ExitCode::Software => 70, + ExitCode::OsErr => 71, + ExitCode::IoError => 74, + ExitCode::TempFail => 75, + ExitCode::Protocol => 76, + ExitCode::Custom(x) => x, + } } } @@ -68,6 +83,19 @@ impl ExitCodeError for TransportError { } } +impl ExitCodeError for RemoteProcessError { + fn to_exit_code(&self) -> ExitCode { + match self { + Self::BadResponse => ExitCode::DataErr, + Self::ChannelDead => ExitCode::Unavailable, + Self::Overloaded => ExitCode::Software, + Self::TransportError(x) => x.to_exit_code(), + Self::UnexpectedEof => ExitCode::IoError, + Self::WaitFailed(_) => ExitCode::Software, + } + } +} + impl<T: ExitCodeError + 'static> From<T> for Box<dyn ExitCodeError> { fn from(x: T) -> Self { Box::new(x) diff --git a/src/cli/link.rs b/src/cli/link.rs new file mode 100644 index 0000000..da850ab --- /dev/null +++ b/src/cli/link.rs @@ -0,0 +1,112 @@ +use crate::{ + cli, + core::{ + client::{ + RemoteLspStderr, RemoteLspStdin, RemoteLspStdout, RemoteStderr, RemoteStdin, + RemoteStdout, + }, + constants::MAX_PIPE_CHUNK_SIZE, + }, +}; +use std::{ + io::{self, Write}, + thread, +}; +use tokio::task::{JoinError, JoinHandle}; + +/// Represents a link between a remote process' stdin/stdout/stderr and this process' +/// stdin/stdout/stderr +pub struct RemoteProcessLink { + _stdin_thread: thread::JoinHandle<()>, + stdin_task: JoinHandle<io::Result<()>>, + stdout_task: JoinHandle<io::Result<()>>, + stderr_task: JoinHandle<io::Result<()>>, +} + +macro_rules! from_pipes { + ($stdin:expr, $stdout:expr, $stderr:expr) => {{ + let (stdin_thread, mut stdin_rx) = cli::stdin::spawn_channel(MAX_PIPE_CHUNK_SIZE); + let stdin_task = tokio::spawn(async move { + loop { + if let Some(input) = stdin_rx.recv().await { + if let Err(x) = $stdin.write(input.as_str()).await { + break Err(x); + } + } else { + break Ok(()); + } + } + }); + let stdout_task = tokio::spawn(async move { + let handle = io::stdout(); + loop { + match $stdout.read().await { + Ok(output) => { + let mut out = handle.lock(); + out.write_all(output.as_bytes())?; + out.flush()?; + } + Err(x) => break Err(x), + } + } + }); + let stderr_task = tokio::spawn(async move { + let handle = io::stderr(); + loop { + match $stderr.read().await { + Ok(output) => { + let mut out = handle.lock(); + out.write_all(output.as_bytes())?; + out.flush()?; + } + Err(x) => break Err(x), + } + } + }); + + RemoteProcessLink { + _stdin_thread: stdin_thread, + stdin_task, + stdout_task, + stderr_task, + } + }}; +} + +impl RemoteProcessLink { + /// Creates a new process link from the pipes of a remote process + pub fn from_remote_pipes( + mut stdin: RemoteStdin, + mut stdout: RemoteStdout, + mut stderr: RemoteStderr, + ) -> Self { + from_pipes!(stdin, stdout, stderr) + } + + /// Creates a new process link from the pipes of a remote LSP server process + pub fn from_remote_lsp_pipes( + mut stdin: RemoteLspStdin, + mut stdout: RemoteLspStdout, + mut stderr: RemoteLspStderr, + ) -> Self { + from_pipes!(stdin, stdout, stderr) + } + + /// Shuts down the link, aborting any running tasks, and swallowing join errors + pub async fn shutdown(self) { + self.abort(); + let _ = self.wait().await; + } + + /// Waits for the stdin, stdout, and stderr tasks to complete + pub async fn wait(self) -> Result<(), JoinError> { + tokio::try_join!(self.stdin_task, self.stdout_task, self.stderr_task).map(|_| ()) + } + + /// 
Aborts the link by aborting tasks processing stdin, stdout, and stderr + pub fn abort(&self) { + self.stdin_task.abort(); + self.stdout_task.abort(); + self.stderr_task.abort(); + } +} diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 53581f0..1d28b96 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -1,10 +1,13 @@ mod buf; mod exit; +mod link; mod opt; mod output; mod session; +mod stdin; mod subcommand; pub use exit::{ExitCode, ExitCodeError}; pub use opt::*; pub use output::ResponseOut; +pub use session::CliSession; diff --git a/src/cli/session.rs b/src/cli/session.rs index 817f02d..b93d1ae 100644 --- a/src/cli/session.rs +++ b/src/cli/session.rs @@ -1,57 +1,84 @@ use crate::{ - cli::{buf::StringBuf, Format, ResponseOut}, + cli::{buf::StringBuf, stdin, Format, ResponseOut}, core::{ client::Session, constants::MAX_PIPE_CHUNK_SIZE, - data::{Request, Response}, + data::{Request, RequestData, Response}, net::DataStream, }, }; use log::*; -use std::{ - io::{self, BufReader, Read}, - sync::Arc, - thread, -}; -use tokio::sync::{mpsc, watch}; +use std::{io, thread}; +use structopt::StructOpt; +use tokio::{sync::mpsc, task::JoinHandle}; use tokio_stream::{wrappers::BroadcastStream, StreamExt}; /// Represents a wrapper around a session that provides CLI functionality such as reading from /// stdin and piping results back out to stdout -pub struct CliSession<T> -where - T: DataStream, -{ - inner: Session<T>, +pub struct CliSession { + stdin_thread: thread::JoinHandle<()>, + req_task: JoinHandle<()>, + res_task: JoinHandle<io::Result<()>>, } -impl<T> CliSession<T> -where - T: DataStream, -{ - pub fn new(inner: Session<T>) -> Self { - Self { inner } +impl CliSession { + pub fn new<T>(tenant: String, mut session: Session<T>, format: Format) -> Self + where + T: DataStream + 'static, + { + let (stdin_thread, stdin_rx) = stdin::spawn_channel(MAX_PIPE_CHUNK_SIZE); + + let (exit_tx, exit_rx) = mpsc::channel(1); + let stream = session.to_response_broadcast_stream(); + let res_task = + tokio::spawn(async move { process_incoming_responses(stream, format, exit_rx).await }); + + let map_line = move |line: &str| match format { + Format::Json => serde_json::from_str(&line) + .map_err(|x| io::Error::new(io::ErrorKind::InvalidInput, x)), + Format::Shell => { + let data = RequestData::from_iter_safe( + std::iter::once("distant") + .chain(line.trim().split(' ').filter(|s| !s.trim().is_empty())), + ) + .map_err(|x| io::Error::new(io::ErrorKind::InvalidInput, x)); + + data.map(|x| Request::new(tenant.to_string(), vec![x])) + } + }; + let req_task = tokio::spawn(async move { + process_outgoing_requests(session, stdin_rx, exit_tx, format, map_line).await + }); + + Self { + stdin_thread, + req_task, + res_task, + } } -} -// TODO TODO TODO: -// -// 1. Change watch to broadcast if going to use in both loops, otherwise just make -// it an mpsc otherwise -// 2. Need to provide outgoing requests function with logic from inner.rs to create a request -// based on the format (json or shell), where json uses serde_json::from_str and shell -// uses Request::new(tenant.as_str(), vec![RequestData::from_iter_safe(...)]) -// 3. Need to add a wait method to block on the running tasks -// 4. Need to add an abort method to abort the tasks -// 5. Is there any way to deal with the blocking thread for stdin to kill it? This isn't a big -// deal as the shutdown would only be happening on client termination anyway, but still... 
+ /// Wait for the cli session to terminate + pub async fn wait(self) -> io::Result<()> { + match tokio::try_join!(self.req_task, self.res_task) { + Ok((_, res)) => res, + Err(x) => Err(io::Error::new(io::ErrorKind::BrokenPipe, x)), + } + } + + /// Aborts the cli session forcing its task handlers to abort underneath, which means that a + /// call to `wait` will return an error + pub async fn abort(&self) { + self.req_task.abort(); + self.res_task.abort(); + } +} /// Helper function that loops, processing incoming responses not tied to a request to be sent out /// over stdout/stderr async fn process_incoming_responses( mut stream: BroadcastStream<Response>, format: Format, - mut exit: watch::Receiver<bool>, + mut exit: mpsc::Receiver<()>, ) -> io::Result<()> { loop { tokio::select! { @@ -62,7 +89,7 @@ async fn process_incoming_responses( None => return Ok(()), } } - _ = exit.changed() => { + _ = exit.recv() => { return Ok(()); } } @@ -74,6 +101,7 @@ async fn process_outgoing_requests<T, F>( mut session: Session<T>, mut stdin_rx: mpsc::Receiver<String>, + exit_tx: mpsc::Sender<()>, format: Format, map_line: F, ) where @@ -90,9 +118,16 @@ // For each complete line, parse into a request if let Some(lines) = lines { - for line in lines.lines() { + for line in lines.lines().map(|line| line.trim()) { trace!("Processing line: {:?}", line); - if line.trim().is_empty() { + if line.is_empty() { + continue; + } else if line == "exit" { + debug!("Got exit request, so closing cli session"); + stdin_rx.close(); + if let Err(_) = exit_tx.send(()).await { + error!("Failed to close cli session"); + } continue; } @@ -114,42 +149,3 @@ } } } - -/// Creates a new thread that performs stdin reads in a blocking fashion, returning -/// a handle to the thread and a receiver that will be sent input as it becomes available -fn spawn_stdin_reader() -> (thread::JoinHandle<()>, mpsc::Receiver<String>) { - let (tx, rx) = mpsc::channel(1); - - // NOTE: Using blocking I/O per tokio's advice to read from stdin line-by-line and then - // pass the results to a separate async handler to forward to the remote process - let handle = thread::spawn(move || { - let mut stdin = BufReader::new(io::stdin()); - - // Maximum chunk that we expect to read at any one time - let mut buf = [0; MAX_PIPE_CHUNK_SIZE]; - - loop { - match stdin.read(&mut buf) { - Ok(0) | Err(_) => break, - Ok(n) => { - match String::from_utf8(buf[..n].to_vec()) { - Ok(text) => { - if let Err(x) = tx.blocking_send(text) { - error!( - "Failed to pass along stdin to be sent to remote process: {}", - x - ); - } - } - Err(x) => { - error!("Input over stdin is invalid: {}", x); - } - } - thread::yield_now(); - } - } - } - }); - - (handle, rx) -} diff --git a/src/cli/stdin.rs b/src/cli/stdin.rs new file mode 100644 index 0000000..d13ded6 --- /dev/null +++ b/src/cli/stdin.rs @@ -0,0 +1,43 @@ +use log::error; +use std::{ + io::{self, BufReader, Read}, + thread, +}; +use tokio::sync::mpsc; + +/// Creates a new thread that performs stdin reads in a blocking fashion, returning +/// a handle to the thread and a receiver that will be sent input as it becomes available +pub fn spawn_channel(buffer: usize) -> (thread::JoinHandle<()>, mpsc::Receiver<String>) { + let (tx, rx) = mpsc::channel(1); + + // NOTE: Using blocking I/O per tokio's advice to read from stdin line-by-line and then + // pass the results to a separate async handler to forward to the remote process + let handle = thread::spawn(move || { + let mut stdin 
= BufReader::new(io::stdin()); + + // Maximum chunk that we expect to read at any one time + let mut buf = vec![0; buffer]; + + loop { + match stdin.read(&mut buf) { + Ok(0) | Err(_) => break, + Ok(n) => { + match String::from_utf8(buf[..n].to_vec()) { + Ok(text) => { + if let Err(x) = tx.blocking_send(text) { + error!("Stdin channel closed: {}", x); + break; + } + } + Err(x) => { + error!("Input over stdin is invalid: {}", x); + } + } + thread::yield_now(); + } + } + } + }); + + (handle, rx) +} diff --git a/src/cli/subcommand/action.rs b/src/cli/subcommand/action.rs new file mode 100644 index 0000000..138a2c8 --- /dev/null +++ b/src/cli/subcommand/action.rs @@ -0,0 +1,186 @@ +use crate::{ + cli::{ + link::RemoteProcessLink, + opt::{ActionSubcommand, CommonOpt, SessionInput}, + CliSession, ExitCode, ExitCodeError, ResponseOut, + }, + core::{ + client::{ + self, LspData, RemoteProcess, RemoteProcessError, Session, SessionInfo, SessionInfoFile, + }, + data::{Request, RequestData}, + net::{DataStream, TransportError}, + }, +}; +use derive_more::{Display, Error, From}; +use tokio::{io, time::Duration}; + +#[derive(Debug, Display, Error, From)] +pub enum Error { + #[display(fmt = "Process failed with exit code: {}", _0)] + BadProcessExit(#[error(not(source))] i32), + IoError(io::Error), + #[display(fmt = "Non-interactive but no operation supplied")] + MissingOperation, + RemoteProcessError(RemoteProcessError), + TransportError(TransportError), +} + +impl ExitCodeError for Error { + fn to_exit_code(&self) -> ExitCode { + match self { + Self::BadProcessExit(x) => ExitCode::Custom(*x), + Self::IoError(x) => x.to_exit_code(), + Self::MissingOperation => ExitCode::Usage, + Self::RemoteProcessError(x) => x.to_exit_code(), + Self::TransportError(x) => x.to_exit_code(), + } + } +} + +pub fn run(cmd: ActionSubcommand, opt: CommonOpt) -> Result<(), Error> { + let rt = tokio::runtime::Runtime::new()?; + + rt.block_on(async { run_async(cmd, opt).await }) +} + +async fn run_async(cmd: ActionSubcommand, opt: CommonOpt) -> Result<(), Error> { + let timeout = opt.to_timeout_duration(); + + match cmd.session { + SessionInput::Environment => { + start( + cmd, + Session::tcp_connect_timeout(SessionInfo::from_environment()?, timeout).await?, + timeout, + None, + ) + .await + } + SessionInput::File => { + let path = cmd.session_data.session_file.clone(); + start( + cmd, + Session::tcp_connect_timeout( + SessionInfoFile::load_from(path).await?.into(), + timeout, + ) + .await?, + timeout, + None, + ) + .await + } + SessionInput::Pipe => { + start( + cmd, + Session::tcp_connect_timeout(SessionInfo::from_stdin()?, timeout).await?, + timeout, + None, + ) + .await + } + SessionInput::Lsp => { + let mut data = + LspData::from_buf_reader(&mut std::io::stdin().lock()).map_err(io::Error::from)?; + let info = data.take_session_info().map_err(io::Error::from)?; + start( + cmd, + Session::tcp_connect_timeout(info, timeout).await?, + timeout, + Some(data), + ) + .await + } + #[cfg(unix)] + SessionInput::Socket => { + let path = cmd.session_data.session_socket.clone(); + start( + cmd, + Session::unix_connect_timeout(path, None, timeout).await?, + timeout, + None, + ) + .await + } + } +} + +async fn start<T>( + cmd: ActionSubcommand, + mut session: Session<T>, + timeout: Duration, + lsp_data: Option<LspData>, +) -> Result<(), Error> +where + T: DataStream + 'static, +{ + // TODO: Because lsp is being handled in a separate action, we should fail if we get + // a session type of lsp for a regular action + match (cmd.interactive, cmd.operation) { 
// ProcRun request is specially handled and we ignore interactive as + // the stdin will be used for sending ProcStdin to remote process + (_, Some(RequestData::ProcRun { cmd, args })) => { + let mut proc = RemoteProcess::spawn(client::new_tenant(), session, cmd, args).await?; + + // If we also parsed an LSP's initialize request for its session, we want to forward + // it along in the case of a process call + if let Some(data) = lsp_data { + proc.stdin.as_mut().unwrap().write(data.to_string()).await?; + } + + // Now, map the remote process' stdin/stdout/stderr to our own process + let link = RemoteProcessLink::from_remote_pipes( + proc.stdin.take().unwrap(), + proc.stdout.take().unwrap(), + proc.stderr.take().unwrap(), + ); + + let (success, exit_code) = proc.wait().await?; + + // Shut down our link + link.shutdown().await; + + if !success { + if let Some(code) = exit_code { + return Err(Error::BadProcessExit(code)); + } else { + return Err(Error::BadProcessExit(1)); + } + } + + Ok(()) + } + + // All other requests without interactive are oneoffs + (false, Some(data)) => { + let res = session + .send_timeout(Request::new(client::new_tenant(), vec![data]), timeout) + .await?; + ResponseOut::new(cmd.format, res)?.print(); + Ok(()) + } + + // Interactive mode will send an optional first request and then continue + // to read stdin to send more + (true, maybe_req) => { + // Send our first request if provided + if let Some(data) = maybe_req { + let res = session + .send_timeout(Request::new(client::new_tenant(), vec![data]), timeout) + .await?; + ResponseOut::new(cmd.format, res)?.print(); + } + + // Enter into CLI session where we receive requests on stdin and send out + // over stdout/stderr + let cli_session = CliSession::new(client::new_tenant(), session, cmd.format); + cli_session.wait().await?; + + Ok(()) + } + + // Not interactive and no operation given + (false, None) => Err(Error::MissingOperation), + } +} diff --git a/src/cli/subcommand/action/inner.rs b/src/cli/subcommand/action/inner.rs deleted file mode 100644 index 37b4867..0000000 --- a/src/cli/subcommand/action/inner.rs +++ /dev/null @@ -1,182 +0,0 @@ -use crate::{ - cli::opt::Format, - core::{ - constants::MAX_PIPE_CHUNK_SIZE, - data::{Error, Request, RequestData, Response, ResponseData}, - net::{Client, DataStream}, - utils::StringBuf, - }, -}; -use derive_more::IsVariant; -use log::*; -use std::marker::Unpin; -use structopt::StructOpt; -use tokio::{ - io::{self, AsyncRead, AsyncWrite}, - sync::{ - mpsc, - oneshot::{self, error::TryRecvError}, - }, -}; -use tokio_stream::StreamExt; - -#[derive(Copy, Clone, PartialEq, Eq, IsVariant)] -pub enum LoopConfig { - Json, - Proc { id: usize }, - Shell, -} - -impl From<LoopConfig> for Format { - fn from(config: LoopConfig) -> Self { - match config { - LoopConfig::Json => Self::Json, - LoopConfig::Proc { .. 
} | LoopConfig::Shell => Self::Shell, - } - } -} - -/// Starts a new action loop that processes requests and receives responses -/// -/// id represents the id of a remote process -pub async fn interactive_loop<T>( - mut client: Client<T>, - tenant: String, - config: LoopConfig, -) -> io::Result<()> -where - T: AsyncRead + AsyncWrite + DataStream + Unpin + 'static, -{ - let mut stream = client.to_response_broadcast_stream(); - - // Create a channel that can report when we should stop the loop based on a received request - let (tx_stop, mut rx_stop) = oneshot::channel::<()>(); - - // We also want to spawn a task to handle sending stdin to the remote process - let mut rx = spawn_stdin_reader(); - tokio::spawn(async move { - let mut buf = StringBuf::new(); - - while let Some(data) = rx.recv().await { - match config { - // Special exit condition for interactive format - _ if buf.trim() == "exit" => { - if let Err(_) = tx_stop.send(()) { - error!("Failed to close interactive loop!"); - } - break; - } - - // For json format, all stdin is treated as individual requests - LoopConfig::Json => { - buf.push_str(&data); - let (lines, new_buf) = buf.into_full_lines(); - buf = new_buf; - - // For each complete line, parse it as json and - if let Some(lines) = lines { - for data in lines.lines() { - debug!("Client sending request: {:?}", data); - let result = serde_json::from_str(&data) - .map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x)); - match result { - Ok(req) => match client.send(req).await { - Ok(res) => match format_response(Format::Json, res) { - Ok(out) => out.print(), - Err(x) => error!("Failed to format response: {}", x), - }, - Err(x) => { - error!("Failed to send request: {}", x) - } - }, - Err(x) => { - error!("Failed to serialize request ('{}'): {}", data, x); - } - } - } - } - } - - // For interactive shell format, parse stdin as individual commands - LoopConfig::Shell => { - buf.push_str(&data); - let (lines, new_buf) = buf.into_full_lines(); - buf = new_buf; - - if let Some(lines) = lines { - for data in lines.lines() { - trace!("Shell processing line: {:?}", data); - if data.trim().is_empty() { - continue; - } - - debug!("Client sending command: {:?}", data); - - // NOTE: We have to stick something in as the first argument as clap/structopt - // expect the binary name as the first item in the iterator - let result = RequestData::from_iter_safe( - std::iter::once("distant") - .chain(data.trim().split(' ').filter(|s| !s.trim().is_empty())), - ); - match result { - Ok(data) => { - match client - .send(Request::new(tenant.as_str(), vec![data])) - .await - { - Ok(res) => match format_response(Format::Shell, res) { - Ok(out) => out.print(), - Err(x) => error!("Failed to format response: {}", x), - }, - Err(x) => { - error!("Failed to send request: {}", x) - } - } - } - Err(x) => { - error!("Failed to parse command: {}", x); - } - } - } - } - } - - // For non-interactive shell format, all stdin is treated as a proc's stdin - LoopConfig::Proc { id } => { - debug!("Client sending stdin: {:?}", data); - let req = - Request::new(tenant.as_str(), vec![RequestData::ProcStdin { id, data }]); - let result = client.send(req).await; - - if let Err(x) = result { - error!("Failed to send stdin to remote process ({}): {}", id, x); - } - } - } - } - }); - - while let Err(TryRecvError::Empty) = rx_stop.try_recv() { - if let Some(res) = stream.next().await { - let res = res.map_err(|x| io::Error::new(io::ErrorKind::BrokenPipe, x))?; - - // NOTE: If the loop is for a proxy process, we should assume that 
the payload - // is all-or-nothing for the done check - let done = config.is_proc() && res.payload.iter().any(|x| x.is_proc_done()); - - format_response(config.into(), res)?.print(); - - // If we aren't interactive but are just running a proc and - // we've received the end of the proc, we should exit - if done { - break; - } - - // If we have nothing else in our stream, we should also exit - } else { - break; - } - } - - Ok(()) -} diff --git a/src/cli/subcommand/action/mod.rs b/src/cli/subcommand/action/mod.rs deleted file mode 100644 index e3bc717..0000000 --- a/src/cli/subcommand/action/mod.rs +++ /dev/null @@ -1,221 +0,0 @@ -use crate::{ - cli::{ - opt::{ActionSubcommand, CommonOpt, Format, SessionInput}, - ExitCode, ExitCodeError, - }, - core::{ - client::{LspData, Session, SessionInfo, SessionInfoFile}, - data::{Request, RequestData, ResponseData}, - net::{DataStream, TransportError}, - }, -}; -use derive_more::{Display, Error, From}; -use log::*; -use tokio::{io, time::Duration}; - -pub(crate) mod inner; - -#[derive(Debug, Display, Error, From)] -pub enum Error { - IoError(io::Error), - TransportError(TransportError), - - #[display(fmt = "Non-interactive but no operation supplied")] - MissingOperation, -} - -impl ExitCodeError for Error { - fn to_exit_code(&self) -> ExitCode { - match self { - Self::IoError(x) => x.to_exit_code(), - Self::TransportError(x) => x.to_exit_code(), - Self::MissingOperation => ExitCode::Usage, - } - } -} - -pub fn run(cmd: ActionSubcommand, opt: CommonOpt) -> Result<(), Error> { - let rt = tokio::runtime::Runtime::new()?; - - rt.block_on(async { run_async(cmd, opt).await }) -} - -async fn run_async(cmd: ActionSubcommand, opt: CommonOpt) -> Result<(), Error> { - let timeout = opt.to_timeout_duration(); - - match cmd.session { - SessionInput::Environment => { - start( - cmd, - Session::tcp_connect_timeout(SessionInfo::from_environment()?, timeout).await?, - timeout, - None, - ) - .await - } - SessionInput::File => { - let path = cmd.session_data.session_file.clone(); - start( - cmd, - Session::tcp_connect_timeout( - SessionInfoFile::load_from(path).await?.into(), - timeout, - ) - .await?, - timeout, - None, - ) - .await - } - SessionInput::Pipe => { - start( - cmd, - Session::tcp_connect_timeout(SessionInfo::from_stdin()?, timeout).await?, - timeout, - None, - ) - .await - } - SessionInput::Lsp => { - let mut data = - LspData::from_buf_reader(&mut std::io::stdin().lock()).map_err(io::Error::from)?; - let info = data.take_session_info().map_err(io::Error::from)?; - start( - cmd, - Session::tcp_connect_timeout(info, timeout).await?, - timeout, - Some(data), - ) - .await - } - #[cfg(unix)] - SessionInput::Socket => { - let path = cmd.session_data.session_socket.clone(); - start( - cmd, - Session::unix_connect_timeout(path, None, timeout).await?, - timeout, - None, - ) - .await - } - } -} - -async fn start<T>( - cmd: ActionSubcommand, - mut session: Session<T>, - timeout: Duration, - lsp_data: Option<LspData>, -) -> Result<(), Error> -where - T: DataStream + 'static, -{ - // TODO: Because lsp is being handled in a separate action, we should fail if we get - // a session type of lsp for a regular action - match (cmd.interactive, cmd.operation) { - // ProcRun request is specially handled and we ignore interactive as - // the stdin will be used for sending ProcStdin to remote process - (_, Some(RequestData::ProcRun { cmd, args })) => {} - - // All other requests without interactive are oneoffs - (false, Some(req)) => { - let res = session.send_timeout(req, timeout).await?; - } - 
- // Interactive mode will send an optional first request and then continue - // to read stdin to send more - (true, maybe_req) => {} - - // Not interactive and no operation given - (false, None) => Err(Error::MissingOperation), - } - - // 1. Determine what type of engagement we're doing - // a. Oneoff connection, request, response - // b. ProcRun where we take over stdin, stdout, stderr to provide a remote - // process experience - // c. Lsp where we do the ProcRun stuff, but translate stdin before sending and - // stdout before outputting - // d. Interactive program - // - // 2. If we have a queued up operation, we need to perform it - // a. For oneoff, this is the request of the oneoff - // b. For Procrun, this is the request that starts the process - // c. For Lsp, this is the request that starts the process - // d. For interactive, this is an optional first request - // - // 3. If we are using LSP session mode, then we want to send the - // ProcStdin request after our optional queued up operation - // a. For oneoff, this doesn't make sense and we should fail - // b. For ProcRun, we do this after the ProcStart - // c. For Lsp, we do this after the ProcStart - // d. For interactive, this doesn't make sense as we only support - // JSON and shell command input, not LSP input, so this would - // fail and we should fail early - // - // ** LSP would be its own action, which means we want to abstract the logic that feeds - // into this start method such that it can also be used with lsp action - - // Make up a tenant name - let tenant = utils::new_tenant(); - - // Special conditions for continuing to process responses - let mut is_proc_req = false; - let mut proc_id = 0; - - if let Some(req) = cmd - .operation - .map(|payload| Request::new(tenant.as_str(), vec![payload])) - { - // NOTE: We know that there is a single payload entry, so it's all-or-nothing - is_proc_req = req.payload.iter().any(|x| x.is_proc_run()); - - debug!("Client sending request: {:?}", req); - let res = session.send_timeout(req, timeout).await?; - - // Store the spawned process id for using in sending stdin (if we spawned a proc) - // NOTE: We can assume that there is a single payload entry in response to our single - // entry in our request - if let Some(ResponseData::ProcStart { id }) = res.payload.first() { - proc_id = *id; - } - - inner::format_response(cmd.format, res)?.print(); - - // If we also parsed an LSP's initialize request for its session, we want to forward - // it along in the case of a process call - // - // TODO: Do we need to do this somewhere else to apply to all possible ways an LSP - // could be started? 
- if let Some(data) = lsp_data { - session - .fire_timeout( - Request::new( - tenant.as_str(), - vec![RequestData::ProcStdin { - id: proc_id, - data: data.to_string(), - }], - ), - timeout, - ) - .await?; - } - } - - // If we are executing a process, we want to continue interacting via stdin and receiving - // results via stdout/stderr - // - // If we are interactive, we want to continue looping regardless - if is_proc_req || cmd.interactive { - let config = match cmd.format { - Format::Json => inner::LoopConfig::Json, - Format::Shell if cmd.interactive => inner::LoopConfig::Shell, - Format::Shell => inner::LoopConfig::Proc { id: proc_id }, - }; - inner::interactive_loop(client, tenant, config).await?; - } - - Ok(()) -} diff --git a/src/cli/subcommand/launch.rs b/src/cli/subcommand/launch.rs index b17f834..7609485 100644 --- a/src/cli/subcommand/launch.rs +++ b/src/cli/subcommand/launch.rs @@ -1,27 +1,18 @@ use crate::{ cli::{ opt::{CommonOpt, Format, LaunchSubcommand, SessionOutput}, - ExitCode, ExitCodeError, + CliSession, ExitCode, ExitCodeError, }, core::{ - constants::CLIENT_BROADCAST_CHANNEL_CAPACITY, - data::{Request, RequestData, Response, ResponseData}, - net::{Client, Transport, TransportReadHalf, TransportWriteHalf}, - session::{Session, SessionFile}, - utils, + client::{self, Session, SessionInfo, SessionInfoFile}, + server::RelayServer, }, }; use derive_more::{Display, Error, From}; use fork::{daemon, Fork}; use log::*; -use std::{marker::Unpin, path::Path, string::FromUtf8Error, sync::Arc}; -use tokio::{ - io::{self, AsyncRead, AsyncWrite}, - process::Command, - runtime::{Handle, Runtime}, - sync::{broadcast, mpsc, oneshot, Mutex}, - time::Duration, -}; +use std::{path::Path, string::FromUtf8Error}; +use tokio::{io, process::Command, runtime::Runtime, time::Duration}; #[derive(Debug, Display, Error, From)] pub enum Error { @@ -44,12 +35,6 @@ impl ExitCodeError for Error { } } -/// Represents state associated with a connection -#[derive(Default)] -struct ConnState { - processes: Vec<usize>, -} - pub fn run(cmd: LaunchSubcommand, opt: CommonOpt) -> Result<(), Error> { let rt = Runtime::new()?; let session_output = cmd.session; @@ -68,7 +53,7 @@ pub fn run(cmd: LaunchSubcommand, opt: CommonOpt) -> Result<(), Error> { match session_output { SessionOutput::File => { debug!("Outputting session to {:?}", session_file); - rt.block_on(async { SessionFile::new(session_file, session).save().await })? + rt.block_on(async { SessionInfoFile::new(session_file, session).save().await })? 
} SessionOutput::Keep => { debug!("Entering interactive loop over stdin"); @@ -139,54 +124,27 @@ pub fn run(cmd: LaunchSubcommand, opt: CommonOpt) -> Result<(), Error> { Ok(()) } -async fn keep_loop(session: Session, format: Format, duration: Duration) -> io::Result<()> { - use crate::cli::subcommand::action::inner; - match Client::tcp_connect_timeout(session, duration).await { - Ok(client) => { - let config = match format { - Format::Json => inner::LoopConfig::Json, - Format::Shell => inner::LoopConfig::Shell, - }; - inner::interactive_loop(client, utils::new_tenant(), config).await +async fn keep_loop(info: SessionInfo, format: Format, duration: Duration) -> io::Result<()> { + match Session::tcp_connect_timeout(info, duration).await { + Ok(session) => { + let cli_session = CliSession::new(client::new_tenant(), session, format); + cli_session.wait().await } Err(x) => Err(x), } } -#[cfg(unix)] async fn socket_loop( socket_path: impl AsRef<Path>, - session: Session, + info: SessionInfo, duration: Duration, fail_if_socket_exists: bool, shutdown_after: Option<Duration>, ) -> io::Result<()> { // We need to form a connection with the actual server to forward requests // and responses between connections - debug!("Connecting to {} {}", session.host, session.port); - let mut client = Client::tcp_connect_timeout(session, duration).await?; - - // Get a copy of our client's broadcaster so we can have each connection - // subscribe to it for new messages filtered by tenant - debug!("Acquiring client broadcaster"); - let broadcaster = client.to_response_broadcaster(); - - // Spawn task to send to the server requests from connections - debug!("Spawning request forwarding task"); - let (req_tx, mut req_rx) = mpsc::channel::<Request>(CLIENT_BROADCAST_CHANNEL_CAPACITY); - tokio::spawn(async move { - while let Some(req) = req_rx.recv().await { - debug!( - "Forwarding request of type{} {} to server", - if req.payload.len() > 1 { "s" } else { "" }, - req.to_payload_type_string() - ); - if let Err(x) = client.fire_timeout(req, duration).await { - error!("Client failed to send request: {:?}", x); - break; - } - } - }); + debug!("Connecting to {} {}", info.host, info.port); + let session = Session::tcp_connect_timeout(info, duration).await?; // Remove the socket file if it already exists if !fail_if_socket_exists && socket_path.as_ref().exists() { @@ -199,205 +157,17 @@ async fn socket_loop( debug!("Binding to unix socket: {:?}", socket_path.as_ref()); let listener = tokio::net::UnixListener::bind(socket_path)?; - let (ct, notify) = utils::new_shutdown_task(Handle::current(), shutdown_after); - - loop { - tokio::select! 
{ result = listener.accept() => {match result { - Ok((conn, _)) => { - // Create a unique id to associate with the connection since its address - // is not guaranteed to have an identifiable string - let conn_id: usize = rand::random(); - - // Establish a proper connection via a handshake, discarding the connection otherwise - let transport = match Transport::from_handshake(conn, None).await { - Ok(transport) => transport, - Err(x) => { - error!("<Conn {}> Failed handshake: {}", conn_id, x); - continue; - } - }; - let (t_read, t_write) = transport.into_split(); - - // Used to alert our response task of the connection's tenant name - // based on the first - let (tenant_tx, tenant_rx) = oneshot::channel(); - - // Create a state we use to keep track of connection-specific data - debug!("<Conn {}> Initializing internal state", conn_id); - let state = Arc::new(Mutex::new(ConnState::default())); - - // Spawn task to continually receive responses from the client that - // may or may not be relevant to the connection, which will filter - // by tenant and then along any response that matches - let res_rx = broadcaster.subscribe(); - let state_2 = Arc::clone(&state); - tokio::spawn(async move { - handle_conn_outgoing(conn_id, state_2, t_write, tenant_rx, res_rx).await; - }); - - // Spawn task to continually read requests from connection and forward - // them along to be sent via the client - let req_tx = req_tx.clone(); - let ct_2 = Arc::clone(&ct); - tokio::spawn(async move { - ct_2.lock().await.increment(); - handle_conn_incoming(conn_id, state, t_read, tenant_tx, req_tx).await; - ct_2.lock().await.decrement(); - debug!("<Conn {}> Disconnected", conn_id); - }); - } - Err(x) => { - error!("Listener failed: {}", x); - break; - } - }} - _ = notify.notified() => { - warn!("Reached shutdown timeout, so terminating"); - break; - } - } - } - - Ok(()) -} - -/// Conn::Request -> Client::Fire -async fn handle_conn_incoming<T>( - conn_id: usize, - state: Arc<Mutex<ConnState>>, - mut reader: TransportReadHalf<T>, - tenant_tx: oneshot::Sender<String>, - req_tx: mpsc::Sender<Request>, -) where - T: AsyncRead + Unpin, -{ - macro_rules! process_req { - ($on_success:expr; $done:expr) => { - match reader.receive::<Request>().await { - Ok(Some(req)) => { - $on_success(&req); - if let Err(x) = req_tx.send(req).await { - error!( - "Failed to pass along request received on unix socket: {:?}", - x - ); - $done; - } - } - Ok(None) => $done, - Err(x) => { - error!("Failed to receive request from unix stream: {:?}", x); - $done; - } - } - }; - } - - let mut tenant = None; - - // NOTE: Have to acquire our first request outside our loop since the oneshot - // sender of the tenant's name is consuming - process_req!( - |req: &Request| { - tenant = Some(req.tenant.clone()); - if let Err(x) = tenant_tx.send(req.tenant.clone()) { - error!("Failed to send along acquired tenant name: {:?}", x); - return; - } - }; - return - ); - - // Loop and process all additional requests - loop { - process_req!(|_| {}; break); - } - - // At this point, we have processed at least one request successfully - // and should have the tenant populated. If we had a failure at the - // beginning, we exit the function early via return. 
- let tenant = tenant.unwrap(); - - // Perform cleanup if done by sending a request to kill each running process - // debug!("Cleaning conn {} :: killing process {}", conn_id, id); - if let Err(x) = req_tx - .send(Request::new( - tenant.clone(), - state - .lock() - .await - .processes - .iter() - .map(|id| RequestData::ProcKill { id: *id }) - .collect(), - )) + let server = RelayServer::initialize(session, listener, shutdown_after).await?; + server + .wait() .await - { - error!("<Conn {}> Failed to send kill signals: {}", conn_id, x); - } -} - -async fn handle_conn_outgoing<T>( - conn_id: usize, - state: Arc<Mutex<ConnState>>, - mut writer: TransportWriteHalf<T>, - tenant_rx: oneshot::Receiver<String>, - mut res_rx: broadcast::Receiver<Response>, -) where - T: AsyncWrite + Unpin, -{ - // We wait for the tenant to be identified by the first request - // before processing responses to be sent back; this is easier - // to implement and yields the same result as we would be dropping - // all responses before we know the tenant - if let Ok(tenant) = tenant_rx.await { - debug!("Associated tenant {} with conn {}", tenant, conn_id); - loop { - match res_rx.recv().await { - // Forward along responses that are for our connection - Ok(res) if res.tenant == tenant => { - debug!( - "Conn {} being sent response of type{} {}", - conn_id, - if res.payload.len() > 1 { "s" } else { "" }, - res.to_payload_type_string(), - ); - - // If a new process was started, we want to capture the id and - // associate it with the connection - let ids = res.payload.iter().filter_map(|x| match x { - ResponseData::ProcStart { id } => Some(*id), - _ => None, - }); - for id in ids { - debug!("Tracking proc {} for conn {}", id, conn_id); - state.lock().await.processes.push(id); - } - - if let Err(x) = writer.send(res).await { - error!("Failed to send response through unix connection: {}", x); - break; - } - } - // Skip responses that are not for our connection - Ok(_) => {} - Err(x) => { - error!( - "Conn {} failed to receive broadcast response: {}", - conn_id, x - ); - break; - } - } - } - } + .map_err(|x| io::Error::new(io::ErrorKind::Other, x)) } /// Spawns a remote server that listens for requests /// /// Returns the session associated with the server -async fn spawn_remote_server(cmd: LaunchSubcommand, _opt: CommonOpt) -> Result<Session, Error> { +async fn spawn_remote_server(cmd: LaunchSubcommand, _opt: CommonOpt) -> Result<SessionInfo, Error> { let distant_command = format!( "{} listen --daemon --host {} {}", cmd.distant, @@ -417,6 +187,7 @@ async fn spawn_remote_server(cmd: LaunchSubcommand, _opt: CommonOpt) -> Result<Session, Error> - .find_map(|line| line.parse::<Session>().ok()) + .find_map(|line| line.parse::<SessionInfo>().ok()) .ok_or(Error::MissingSessionData)?; - session.host = cmd.host; + info.host = cmd.host; - Ok(session) + Ok(info) } diff --git a/src/cli/subcommand/listen.rs b/src/cli/subcommand/listen.rs index 12149c2..653aabe 100644 --- a/src/cli/subcommand/listen.rs +++ b/src/cli/subcommand/listen.rs @@ -68,7 +68,7 @@ async fn run_async(cmd: ListenSubcommand, _opt: CommonOpt, is_forked: bool) -> R let server = DistantServer::bind( addr, cmd.port, - cmd.to_shutdown_after_duration(), + shutdown_after, cmd.max_msg_capacity as usize, ) .await?; diff --git a/src/core/client/lsp/data.rs b/src/core/client/lsp/data.rs index c04046c..d88ad0a 100644 --- a/src/core/client/lsp/data.rs +++ b/src/core/client/lsp/data.rs @@ -297,7 +297,7 @@ impl FromStr for LspHeader { #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct LspContent(Map<String, Value>); -fn for_each_mut_string<F1, F2>(value: &mut Value, check: F1, mutate: F2) +fn for_each_mut_string<F1, F2>(value: &mut Value, 
check: &F1, mutate: &mut F2) where F1: Fn(&String) -> bool, F2: FnMut(&mut String), @@ -309,12 +309,15 @@ where .for_each(|v| for_each_mut_string(v, check, mutate)); // Mutate keys if necessary - for key in obj.keys() { - if check(key) { - if let Some((key, value)) = obj.remove_entry(key) { - mutate(&mut key); - obj.insert(key, value); - } + let keys: Vec<String> = obj + .keys() + .filter(|k| check(k)) + .map(ToString::to_string) + .collect(); + for key in keys { + if let Some((mut key, value)) = obj.remove_entry(&key) { + mutate(&mut key); + obj.insert(key, value); } } } @@ -328,7 +331,7 @@ where fn swap_prefix(obj: &mut Map<String, Value>, old: &str, new: &str) { let check = |s: &String| s.starts_with(old); - let mutate = |s: &mut String| { + let mut mutate = |s: &mut String| { if let Some(pos) = s.find(old) { s.replace_range(pos..old.len(), new); } @@ -336,15 +339,18 @@ fn swap_prefix(obj: &mut Map<String, Value>, old: &str, new: &str) { // Mutate values obj.values_mut() - .for_each(|v| for_each_mut_string(v, check, mutate)); + .for_each(|v| for_each_mut_string(v, &check, &mut mutate)); // Mutate keys if necessary - for key in obj.keys() { - if check(key) { - if let Some((key, value)) = obj.remove_entry(key) { - mutate(&mut key); - obj.insert(key, value); - } + let keys: Vec<String> = obj + .keys() + .filter(|k| check(k)) + .map(ToString::to_string) + .collect(); + for key in keys { + if let Some((mut key, value)) = obj.remove_entry(&key) { + mutate(&mut key); + obj.insert(key, value); } } } @@ -528,7 +534,11 @@ mod tests { fn data_from_buf_reader_should_fail_if_reach_eof_before_received_full_data() { // No line termination let err = LspData::from_buf_reader(&mut io::Cursor::new("Content-Length: 22")).unwrap_err(); - assert!(matches!(err, LspDataParseError::UnexpectedEof), "{:?}", err); + assert!( + matches!(err, LspDataParseError::BadHeaderTermination), + "{:?}", + err + ); // Header doesn't finish let err = LspData::from_buf_reader(&mut io::Cursor::new(concat!( @@ -1100,7 +1110,7 @@ mod tests { #[test] fn content_convert_distant_scheme_to_local_should_convert_keys_and_values() { - let content = LspContent(make_obj!({ + let mut content = LspContent(make_obj!({ "distant://key1": "file://value1", "file://key2": "distant://value2", "key3": ["file://value3", "distant://value4"], diff --git a/src/core/client/lsp/mod.rs b/src/core/client/lsp/mod.rs index 61ebb2f..ac82fc9 100644 --- a/src/core/client/lsp/mod.rs +++ b/src/core/client/lsp/mod.rs @@ -21,6 +21,7 @@ impl RemoteLspProcess { /// Spawns the specified process on the remote machine using the given session, treating /// the process like an LSP server pub async fn spawn<T>( + tenant: String, session: Session<T>, cmd: String, args: Vec<String>, @@ -28,7 +29,7 @@ impl RemoteLspProcess { where T: DataStream + 'static, { - let mut inner = RemoteProcess::spawn(session, cmd, args).await?; + let mut inner = RemoteProcess::spawn(tenant, session, cmd, args).await?; let stdin = inner.stdin.take().map(RemoteLspStdin::new); let stdout = inner.stdout.take().map(RemoteLspStdout::new); let stderr = inner.stderr.take().map(RemoteLspStderr::new); diff --git a/src/core/client/mod.rs b/src/core/client/mod.rs index 5b2c906..da2cfe4 100644 --- a/src/core/client/mod.rs +++ b/src/core/client/mod.rs @@ -3,14 +3,7 @@ mod process; mod session; mod utils; -// TODO: Make wrappers around a connection to facilitate the types -// of engagements -// -// 1. Command -> Single request/response through a future -// 2. 
Proxy -> Does proc-run and waits until proc-done received, -// exposing a sender for stdin and receivers for stdout/stderr, -// and supporting a future await for completion with exit code -// 3. pub use lsp::*; pub use process::{RemoteProcess, RemoteProcessError, RemoteStderr, RemoteStdin, RemoteStdout}; pub use session::*; +pub(crate) use utils::new_tenant; diff --git a/src/core/client/process.rs b/src/core/client/process.rs index 29b43ad..babc746 100644 --- a/src/core/client/process.rs +++ b/src/core/client/process.rs @@ -1,5 +1,5 @@ use crate::core::{ - client::{utils, Session}, + client::Session, constants::CLIENT_BROADCAST_CHANNEL_CAPACITY, data::{Request, RequestData, Response, ResponseData}, net::{DataStream, TransportError}, @@ -63,6 +63,7 @@ pub struct RemoteProcess { impl RemoteProcess { /// Spawns the specified process on the remote machine using the given session pub async fn spawn<T>( + tenant: String, mut session: Session<T>, cmd: String, args: Vec<String>, @@ -70,8 +71,6 @@ impl RemoteProcess { where T: DataStream + 'static, { - let tenant = utils::new_tenant(); - // Submit our run request and wait for a response let res = session .send(Request::new( @@ -127,7 +126,10 @@ impl RemoteProcess { /// Waits for the process to terminate, returning the success status and an optional exit code pub async fn wait(self) -> Result<(bool, Option<i32>), RemoteProcessError> { - self.res_task.await? + match tokio::try_join!(self.req_task, self.res_task) { + Ok((_, res)) => res, + Err(x) => Err(RemoteProcessError::from(x)), + } } /// Aborts the process by forcing its response task to shutdown, which means that a call diff --git a/src/core/client/session/mod.rs b/src/core/client/session/mod.rs index 7258f73..d43a4d5 100644 --- a/src/core/client/session/mod.rs +++ b/src/core/client/session/mod.rs @@ -50,7 +50,7 @@ where impl Session<InmemoryStream> { /// Creates a session around an inmemory transport pub async fn from_inmemory_transport(transport: Transport<InmemoryStream>) -> io::Result<Self> { - Self::inner_connect(transport).await + Self::initialize(transport).await } } @@ -67,7 +67,7 @@ impl Session<TcpStream> { .map(|x| x.to_string()) .unwrap_or_else(|_| String::from("???")) ); - Self::inner_connect(transport).await + Self::initialize(transport).await } /// Connect to a remote TCP server, timing out after duration has passed @@ -93,7 +93,7 @@ impl Session<tokio::net::UnixStream> { .map(|x| format!("{:?}", x)) .unwrap_or_else(|_| String::from("???")) ); - Self::inner_connect(transport).await + Self::initialize(transport).await } /// Connect to a proxy unix socket, timing out after duration has passed @@ -112,8 +112,8 @@ impl<T> Session<T> where T: DataStream, { - /// Establishes a connection using the provided transport - async fn inner_connect(transport: Transport<T>) -> io::Result<Self> { + /// Initializes a session using the provided transport + pub async fn initialize(transport: Transport<T>) -> io::Result<Self> { let (mut t_read, t_write) = transport.into_split(); let callbacks: Callbacks = Arc::new(Mutex::new(HashMap::new())); let (broadcast, init_broadcast_receiver) = @@ -243,14 +243,13 @@ mod tests { use crate::core::{ constants::test::TENANT, data::{RequestData, ResponseData}, - net::test::make_transport_pair, }; use std::time::Duration; #[tokio::test] async fn send_should_wait_until_response_received() { - let (t1, mut t2) = make_transport_pair(); - let mut session = Session::inner_connect(t1).await.unwrap(); + let (t1, mut t2) = Transport::make_pair(); + let mut session = Session::initialize(t1).await.unwrap(); let req = Request::new(TENANT, vec![RequestData::ProcList {}]); let res = 
Response::new( @@ -270,8 +269,8 @@ mod tests { #[tokio::test] async fn send_timeout_should_fail_if_response_not_received_in_time() { - let (t1, mut t2) = make_transport_pair(); - let mut session = Session::inner_connect(t1).await.unwrap(); + let (t1, mut t2) = Transport::make_pair(); + let mut session = Session::initialize(t1).await.unwrap(); let req = Request::new(TENANT, vec![RequestData::ProcList {}]); match session.send_timeout(req, Duration::from_millis(30)).await { @@ -285,8 +284,8 @@ mod tests { #[tokio::test] async fn fire_should_send_request_and_not_wait_for_response() { - let (t1, mut t2) = make_transport_pair(); - let mut session = Session::inner_connect(t1).await.unwrap(); + let (t1, mut t2) = Transport::make_pair(); + let mut session = Session::initialize(t1).await.unwrap(); let req = Request::new(TENANT, vec![RequestData::ProcList {}]); match session.fire(req).await { diff --git a/src/core/net/listener.rs b/src/core/net/listener.rs new file mode 100644 index 0000000..5c1795b --- /dev/null +++ b/src/core/net/listener.rs @@ -0,0 +1,48 @@ +use super::DataStream; +use std::{future::Future, pin::Pin}; +use tokio::{ + io, + net::{TcpListener, TcpStream}, +}; + +/// Represents a type that has a listen interface +pub trait Listener: Send + Sync { + type Conn: DataStream; + + /// Async function that accepts a new connection, returning `Ok(Self::Conn)` + /// upon receiving the next connection + fn accept<'a>(&'a self) -> Pin<Box<dyn Future<Output = io::Result<Self::Conn>> + Send + 'a>> + where + Self: Sync + 'a; +} + +impl Listener for TcpListener { + type Conn = TcpStream; + + fn accept<'a>(&'a self) -> Pin<Box<dyn Future<Output = io::Result<Self::Conn>> + Send + 'a>> + where + Self: Sync + 'a, + { + async fn accept(_self: &TcpListener) -> io::Result<TcpStream> { + _self.accept().await.map(|(stream, _)| stream) + } + + Box::pin(accept(self)) + } +} + +#[cfg(unix)] +impl Listener for tokio::net::UnixListener { + type Conn = tokio::net::UnixStream; + + fn accept<'a>(&'a self) -> Pin<Box<dyn Future<Output = io::Result<Self::Conn>> + Send + 'a>> + where + Self: Sync + 'a, + { + async fn accept(_self: &tokio::net::UnixListener) -> io::Result<tokio::net::UnixStream> { + _self.accept().await.map(|(stream, _)| stream) + } + + Box::pin(accept(self)) + } +} diff --git a/src/core/net/mod.rs b/src/core/net/mod.rs index e40dbf1..2499f31 100644 --- a/src/core/net/mod.rs +++ b/src/core/net/mod.rs @@ -1,4 +1,7 @@ +mod listener; mod transport; + +pub use listener::Listener; pub use transport::*; // Re-export commonly-used orion structs diff --git a/src/core/net/transport/inmemory.rs b/src/core/net/transport/inmemory.rs index 3a3e15d..ee41a6b 100644 --- a/src/core/net/transport/inmemory.rs +++ b/src/core/net/transport/inmemory.rs @@ -1,6 +1,7 @@ -use super::DataStream; +use super::{DataStream, SecretKey, Transport}; use std::{ pin::Pin, + sync::Arc, task::{Context, Poll}, }; use tokio::{ @@ -118,10 +119,146 @@ impl DataStream for InmemoryStream { type Write = InmemoryStreamWriteHalf; fn to_connection_tag(&self) -> String { - String::from("test-stream") + String::from("inmemory-stream") } fn into_split(self) -> (Self::Read, Self::Write) { (self.incoming, self.outgoing) } } + +impl Transport<InmemoryStream> { + /// Produces a pair of inmemory transports that are connected to each other with matching + /// auth and encryption keys + /// + /// Sets the buffer for message passing for each underlying stream to the given buffer size + pub fn pair(buffer: usize) -> (Transport<InmemoryStream>, Transport<InmemoryStream>) { + let auth_key = Arc::new(SecretKey::default()); + let crypt_key = Arc::new(SecretKey::default()); + + let (a, b) = InmemoryStream::pair(buffer); + let a = Transport::new(a, Some(Arc::clone(&auth_key)), 
Arc::clone(&crypt_key)); + let b = Transport::new(b, Some(auth_key), crypt_key); + (a, b) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + #[test] + fn to_connection_tag_should_be_hardcoded_string() { + let (_, _, stream) = InmemoryStream::make(1); + assert_eq!(stream.to_connection_tag(), "inmemory-stream"); + } + + #[tokio::test] + async fn make_should_return_sender_that_sends_data_to_stream() { + let (tx, _, mut stream) = InmemoryStream::make(3); + + tx.send(b"test msg 1".to_vec()).await.unwrap(); + tx.send(b"test msg 2".to_vec()).await.unwrap(); + tx.send(b"test msg 3".to_vec()).await.unwrap(); + + // Should get data matching a singular message + let mut buf = [0; 256]; + let len = stream.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..len], b"test msg 1"); + + // Next call would get the second message + let len = stream.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..len], b"test msg 2"); + + // When the last of the senders is dropped, we should still get + // the rest of the data that was sent first before getting + // an indicator that there is no more data + drop(tx); + + let len = stream.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..len], b"test msg 3"); + + let len = stream.read(&mut buf).await.unwrap(); + assert_eq!(len, 0, "Unexpectedly got more data"); + } + + #[tokio::test] + async fn make_should_return_receiver_that_receives_data_from_stream() { + let (_, mut rx, mut stream) = InmemoryStream::make(3); + + stream.write_all(b"test msg 1").await.unwrap(); + stream.write_all(b"test msg 2").await.unwrap(); + stream.write_all(b"test msg 3").await.unwrap(); + + // Should get data matching a singular message + assert_eq!(rx.recv().await, Some(b"test msg 1".to_vec())); + + // Next call would get the second message + assert_eq!(rx.recv().await, Some(b"test msg 2".to_vec())); + + // When the stream is dropped, we should still get + // the rest of the data that was sent first before getting + // an indicator that there is no more data + drop(stream); + + assert_eq!(rx.recv().await, Some(b"test msg 3".to_vec())); + + assert_eq!(rx.recv().await, None, "Unexpectedly got more data"); + } + + #[tokio::test] + async fn into_split_should_provide_a_read_half_that_receives_from_sender() { + let (tx, _, stream) = InmemoryStream::make(3); + let (mut read_half, _) = stream.into_split(); + + tx.send(b"test msg 1".to_vec()).await.unwrap(); + tx.send(b"test msg 2".to_vec()).await.unwrap(); + tx.send(b"test msg 3".to_vec()).await.unwrap(); + + // Should get data matching a singular message + let mut buf = [0; 256]; + let len = read_half.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..len], b"test msg 1"); + + // Next call would get the second message + let len = read_half.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..len], b"test msg 2"); + + // When the last of the senders is dropped, we should still get + // the rest of the data that was sent first before getting + // an indicator that there is no more data + drop(tx); + + let len = read_half.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..len], b"test msg 3"); + + let len = read_half.read(&mut buf).await.unwrap(); + assert_eq!(len, 0, "Unexpectedly got more data"); + } + + #[tokio::test] + async fn into_split_should_provide_a_write_half_that_sends_to_receiver() { + let (_, mut rx, stream) = InmemoryStream::make(3); + let (_, mut write_half) = stream.into_split(); + + write_half.write_all(b"test msg 1").await.unwrap(); + write_half.write_all(b"test msg 2").await.unwrap(); + 
write_half.write_all(b"test msg 3").await.unwrap(); + + // Should get data matching a singular message + assert_eq!(rx.recv().await, Some(b"test msg 1".to_vec())); + + // Next call would get the second message + assert_eq!(rx.recv().await, Some(b"test msg 2".to_vec())); + + // When the stream is dropped, we should still get + // the rest of the data that was sent first before getting + // an indicator that there is no more data + drop(write_half); + + assert_eq!(rx.recv().await, Some(b"test msg 3".to_vec())); + + assert_eq!(rx.recv().await, None, "Unexpectedly got more data"); + } +} diff --git a/src/core/net/transport/mod.rs b/src/core/net/transport/mod.rs index b9d5ccb..a2f4fa2 100644 --- a/src/core/net/transport/mod.rs +++ b/src/core/net/transport/mod.rs @@ -346,31 +346,25 @@ where /// Test utilities #[cfg(test)] -pub mod test { - use super::*; - - use crate::core::constants::test::BUFFER_SIZE; - use crate::net::InmemoryStream; - use orion::aead::SecretKey; - +impl Transport<InmemoryStream> { /// Makes a connected pair of transports with matching crypt keys and using the provided /// auth keys - pub fn make_transport_pair_with_auth_keys( + pub fn make_pair_with_auth_keys( ak1: Option<Arc<SecretKey>>, ak2: Option<Arc<SecretKey>>, ) -> (Transport<InmemoryStream>, Transport<InmemoryStream>) { let crypt_key = Arc::new(SecretKey::default()); - let (a, b) = InmemoryStream::pair(BUFFER_SIZE); + let (a, b) = InmemoryStream::pair(crate::core::constants::test::BUFFER_SIZE); let a = Transport::new(a, ak1, Arc::clone(&crypt_key)); let b = Transport::new(b, ak2, crypt_key); (a, b) } /// Makes a connected pair of transports with matching auth and crypt keys - pub fn make_transport_pair() -> (Transport<InmemoryStream>, Transport<InmemoryStream>) { - let auth_key = Arc::new(SecretKey::default()); - make_transport_pair_with_auth_keys(Some(Arc::clone(&auth_key)), Some(auth_key)) + /// using test buffer size + pub fn make_pair() -> (Transport<InmemoryStream>, Transport<InmemoryStream>) { + Self::pair(crate::core::constants::test::BUFFER_SIZE) } } @@ -380,8 +374,6 @@ mod tests { use crate::core::constants::test::BUFFER_SIZE; use std::io; - use test::make_transport_pair_with_auth_keys; - #[tokio::test] async fn transport_from_handshake_should_fail_if_connection_reached_eof() { // Cause nothing left incoming to stream by _ @@ -462,7 +454,7 @@ mod tests { #[tokio::test] async fn transport_should_be_able_to_send_encrypted_data_to_other_side_to_decrypt() { // Make two transports with no auth keys - let (mut src, mut dst) = make_transport_pair_with_auth_keys(None, None); + let (mut src, mut dst) = Transport::make_pair_with_auth_keys(None, None); src.send("some data").await.expect("Failed to send data"); let data = dst @@ -480,7 +472,7 @@ mod tests { // Make two transports with same auth keys let (mut src, mut dst) = - make_transport_pair_with_auth_keys(Some(Arc::clone(&auth_key)), Some(auth_key)); + Transport::make_pair_with_auth_keys(Some(Arc::clone(&auth_key)), Some(auth_key)); src.send("some data").await.expect("Failed to send data"); let data = dst @@ -495,7 +487,7 @@ mod tests { #[tokio::test] async fn transport_receive_should_fail_if_auth_key_differs_from_other_end() { // Make two transports with different auth keys - let (mut src, mut dst) = make_transport_pair_with_auth_keys( + let (mut src, mut dst) = Transport::make_pair_with_auth_keys( Some(Arc::new(SecretKey::default())), Some(Arc::new(SecretKey::default())), ); @@ -511,7 +503,7 @@ mod tests { async fn transport_receive_should_fail_if_has_auth_key_while_sender_did_not_use_one() { // Make two transports with different auth keys let (mut src, mut dst) = -
make_transport_pair_with_auth_keys(None, Some(Arc::new(SecretKey::default()))); + Transport::make_pair_with_auth_keys(None, Some(Arc::new(SecretKey::default()))); src.send("some data").await.expect("Failed to send data"); @@ -529,7 +521,7 @@ mod tests { async fn transport_receive_should_fail_if_has_no_auth_key_while_sender_used_one() { // Make two transports with different auth keys let (mut src, mut dst) = - make_transport_pair_with_auth_keys(Some(Arc::new(SecretKey::default())), None); + Transport::make_pair_with_auth_keys(Some(Arc::new(SecretKey::default())), None); src.send("some data").await.expect("Failed to send data"); match dst.receive::<String>().await { diff --git a/src/core/server/distant/handler.rs b/src/core/server/distant/handler.rs index b1aa303..375d574 100644 --- a/src/core/server/distant/handler.rs +++ b/src/core/server/distant/handler.rs @@ -3,14 +3,13 @@ use crate::core::{ data::{ self, DirEntry, FileType, Request, RequestData, Response, ResponseData, RunningProcess, }, - server::state::{Process, State}, + server::distant::state::{Process, State}, }; use derive_more::{Display, Error, From}; use futures::future; use log::*; use std::{ env, - net::SocketAddr, path::{Path, PathBuf}, process::Stdio, sync::Arc, @@ -24,7 +23,7 @@ use tokio::{ io::{self, AsyncWriteExt}, process::Command, sync::{mpsc, oneshot, Mutex}, }; use walkdir::WalkDir; pub type Reply = mpsc::Sender<Response>; -type HState = Arc<Mutex<State<SocketAddr>>>; +type HState = Arc<Mutex<State>>; #[derive(Debug, Display, Error, From)] pub enum ServerError { @@ -43,14 +42,14 @@ impl From<ServerError> for ResponseData { /// Processes the provided request, sending replies using the given sender pub(super) async fn process( - addr: SocketAddr, + conn_id: usize, state: HState, req: Request, tx: Reply, ) -> Result<(), mpsc::error::SendError<Response>> { async fn inner( tenant: Arc<String>, - addr: SocketAddr, + conn_id: usize, state: HState, data: RequestData, tx: Reply, @@ -76,7 +75,7 @@ pub(super) async fn process( RequestData::Exists { path } => exists(path).await, RequestData::Metadata { path, canonicalize } => metadata(path, canonicalize).await, RequestData::ProcRun { cmd, args } => { - proc_run(tenant.to_string(), addr, state, tx, cmd, args).await + proc_run(tenant.to_string(), conn_id, state, tx, cmd, args).await } RequestData::ProcKill { id } => proc_kill(state, id).await, RequestData::ProcStdin { id, data } => proc_stdin(state, id, data).await, @@ -94,7 +93,7 @@ pub(super) async fn process( let state_2 = Arc::clone(&state); let tx_2 = tx.clone(); payload_tasks.push(tokio::spawn(async move { - match inner(tenant_2, addr, state_2, data, tx_2).await { + match inner(tenant_2, conn_id, state_2, data, tx_2).await { Ok(data) => data, Err(x) => ResponseData::from(x), } @@ -114,8 +113,8 @@ pub(super) async fn process( let res = Response::new(req.tenant, Some(req.id), payload); debug!( - "<Client @ {}> Sending response of type{} {}", - addr, + "<Conn @ {}> Sending response of type{} {}", + conn_id, if res.payload.len() > 1 { "s" } else { "" }, res.to_payload_type_string() ); @@ -358,7 +357,7 @@ async fn metadata(path: PathBuf, canonicalize: bool) -> Result<ResponseData, ServerError> { debug!( - "<Client @ {}> Sending response of type{} {}", - addr, + "<Conn @ {}> Sending response of type{} {}", + conn_id, if res.payload.len() > 1 { "s" } else { "" }, res.to_payload_type_string() ); @@ -430,8 +429,8 @@ async fn proc_run( vec![ResponseData::ProcStderr { id, data }], ); debug!( - "<Client @ {}> Sending response of type{} {}", - addr, + "<Conn @ {}> Sending response of type{} {}", + conn_id, if res.payload.len() > 1 { "s" } else { "" }, res.to_payload_type_string() ); @@ -491,8 +490,8 @@ async fn proc_run( vec![ResponseData::ProcDone { id, success, code }] ); debug!( - "<Client @ {}> Sending response of
type{} {}", - addr, + " Sending response of type{} {}", + conn_id, if res.payload.len() > 1 { "s" } else { "" }, res.to_payload_type_string() ); @@ -503,8 +502,8 @@ async fn proc_run( Err(x) => { let res = Response::new(tenant.as_str(), None, vec![ResponseData::from(x)]); debug!( - " Sending response of type{} {}", - addr, + " Sending response of type{} {}", + conn_id, if res.payload.len() > 1 { "s" } else { "" }, res.to_payload_type_string() ); @@ -533,8 +532,8 @@ async fn proc_run( id, success: false, code: None }]); debug!( - " Sending response of type{} {}", - addr, + " Sending response of type{} {}", + conn_id, if res.payload.len() > 1 { "s" } else { "" }, res.to_payload_type_string() ); @@ -556,7 +555,7 @@ async fn proc_run( stdin_tx, kill_tx, }; - state.lock().await.push_process(addr, process); + state.lock().await.push_process(conn_id, process); Ok(ResponseData::ProcStart { id }) } @@ -609,3 +608,676 @@ async fn system_info() -> Result { main_separator: std::path::MAIN_SEPARATOR, }) } + +#[cfg(test)] +mod tests { + use super::*; + + use std::io::Write; + use tempfile::{NamedTempFile, TempDir}; + + fn setup( + buffer: usize, + ) -> ( + usize, + Arc>, + mpsc::Sender, + mpsc::Receiver, + ) { + let (tx, rx) = mpsc::channel(buffer); + ( + rand::random(), + Arc::new(Mutex::new(State::default())), + tx, + rx, + ) + } + + /// Create a temporary path that does not exist + fn temppath() -> PathBuf { + // Deleted when dropped + NamedTempFile::new().unwrap().into_temp_path().to_path_buf() + } + + #[tokio::test] + async fn file_read_should_send_error_if_fails_to_read_file() { + let (conn_id, state, tx, mut rx) = setup(1); + + // Create a file and then delete it, keeping just its path + let path = temppath(); + + let req = Request::new("test-tenant", vec![RequestData::FileRead { path }]); + + process(conn_id, state, req, tx).await.unwrap(); + + let res = rx.recv().await.unwrap(); + assert_eq!(res.payload.len(), 1, "Wrong payload size"); + assert!( + matches!(res.payload[0], ResponseData::Error(_)), + "Unexpected response: {:?}", + res.payload[0] + ); + } + + #[tokio::test] + async fn file_read_should_send_blob_with_file_contents() { + let (conn_id, state, tx, mut rx) = setup(1); + + // Create a temporary file and fill it with some contents + let mut file = NamedTempFile::new().unwrap(); + file.write_all(b"some file contents").unwrap(); + + let req = Request::new( + "test-tenant", + vec![RequestData::FileRead { + path: file.path().to_path_buf(), + }], + ); + + process(conn_id, state, req, tx).await.unwrap(); + + let res = rx.recv().await.unwrap(); + assert_eq!(res.payload.len(), 1, "Wrong payload size"); + match &res.payload[0] { + ResponseData::Blob { data } => assert_eq!(data, b"some file contents"), + x => panic!("Unexpected response: {:?}", x), + } + } + + #[tokio::test] + async fn file_read_text_should_send_error_if_fails_to_read_file() { + let (conn_id, state, tx, mut rx) = setup(1); + + // Create a file and then delete it, keeping just its path + let path = temppath(); + + let req = Request::new( + "test-tenant", + vec![RequestData::FileReadText { path: path }], + ); + + process(conn_id, state, req, tx).await.unwrap(); + + let res = rx.recv().await.unwrap(); + assert_eq!(res.payload.len(), 1, "Wrong payload size"); + assert!( + matches!(res.payload[0], ResponseData::Error(_)), + "Unexpected response: {:?}", + res.payload[0] + ); + } + + #[tokio::test] + async fn file_read_text_should_send_text_with_file_contents() { + let (conn_id, state, tx, mut rx) = setup(1); + + // Create a temporary 
file and fill it with some contents + let mut file = NamedTempFile::new().unwrap(); + file.write_all(b"some file contents").unwrap(); + + let req = Request::new( + "test-tenant", + vec![RequestData::FileReadText { + path: file.path().to_path_buf(), + }], + ); + + process(conn_id, state, req, tx).await.unwrap(); + + let res = rx.recv().await.unwrap(); + assert_eq!(res.payload.len(), 1, "Wrong payload size"); + match &res.payload[0] { + ResponseData::Text { data } => assert_eq!(data, "some file contents"), + x => panic!("Unexpected response: {:?}", x), + } + } + + #[tokio::test] + async fn file_write_should_send_error_if_fails_to_write_file() { + let (conn_id, state, tx, mut rx) = setup(1); + + // Create a temporary path and add to it to ensure that there are + // extra components that don't exist to cause writing to fail + let path = temppath().join("some_file"); + + let req = Request::new( + "test-tenant", + vec![RequestData::FileWrite { + path: path.clone(), + data: b"some text".to_vec(), + }], + ); + + process(conn_id, state, req, tx).await.unwrap(); + + let res = rx.recv().await.unwrap(); + assert_eq!(res.payload.len(), 1, "Wrong payload size"); + assert!( + matches!(res.payload[0], ResponseData::Error(_)), + "Unexpected response: {:?}", + res.payload[0] + ); + + // Also verify that we didn't actually create the file + assert!(!path.exists(), "File created unexpectedly"); + } + + #[tokio::test] + async fn file_write_should_send_ok_when_successful() { + let (conn_id, state, tx, mut rx) = setup(1); + + // Path should point to a file that does not exist, but all + // other components leading up to it do + let path = temppath(); + + let req = Request::new( + "test-tenant", + vec![RequestData::FileWrite { + path: path.clone(), + data: b"some text".to_vec(), + }], + ); + + process(conn_id, state, req, tx).await.unwrap(); + + let res = rx.recv().await.unwrap(); + assert_eq!(res.payload.len(), 1, "Wrong payload size"); + assert!( + matches!(res.payload[0], ResponseData::Ok), + "Unexpected response: {:?}", + res.payload[0] + ); + + // Also verify that we actually did create the file + // with the associated contents + assert!(path.exists(), "File not actually created"); + assert_eq!(tokio::fs::read_to_string(path).await.unwrap(), "some text"); + } + + #[tokio::test] + async fn file_write_text_should_send_error_if_fails_to_write_file() { + let (conn_id, state, tx, mut rx) = setup(1); + + // Create a temporary path and add to it to ensure that there are + // extra components that don't exist to cause writing to fail + let path = temppath().join("some_file"); + + let req = Request::new( + "test-tenant", + vec![RequestData::FileWriteText { + path: path.clone(), + text: String::from("some text"), + }], + ); + + process(conn_id, state, req, tx).await.unwrap(); + + let res = rx.recv().await.unwrap(); + assert_eq!(res.payload.len(), 1, "Wrong payload size"); + assert!( + matches!(res.payload[0], ResponseData::Error(_)), + "Unexpected response: {:?}", + res.payload[0] + ); + + // Also verify that we didn't actually create the file + assert!(!path.exists(), "File created unexpectedly"); + } + + #[tokio::test] + async fn file_write_text_should_send_ok_when_successful() { + let (conn_id, state, tx, mut rx) = setup(1); + + // Path should point to a file that does not exist, but all + // other components leading up to it do + let path = temppath(); + + let req = Request::new( + "test-tenant", + vec![RequestData::FileWriteText { + path: path.clone(), + text: String::from("some text"), + }], + ); + + 
process(conn_id, state, req, tx).await.unwrap(); + + let res = rx.recv().await.unwrap(); + assert_eq!(res.payload.len(), 1, "Wrong payload size"); + assert!( + matches!(res.payload[0], ResponseData::Ok), + "Unexpected response: {:?}", + res.payload[0] + ); + + // Also verify that we actually did create the file + // with the associated contents + assert!(path.exists(), "File not actually created"); + assert_eq!(tokio::fs::read_to_string(path).await.unwrap(), "some text"); + } + + #[tokio::test] + async fn file_append_should_send_error_if_fails_to_create_file() { + let (conn_id, state, tx, mut rx) = setup(1); + + // Create a temporary path and add to it to ensure that there are + // extra components that don't exist to cause writing to fail + let path = temppath().join("some_file"); + + let req = Request::new( + "test-tenant", + vec![RequestData::FileAppend { + path: path.to_path_buf(), + data: b"some extra contents".to_vec(), + }], + ); + + process(conn_id, state, req, tx).await.unwrap(); + + let res = rx.recv().await.unwrap(); + assert_eq!(res.payload.len(), 1, "Wrong payload size"); + assert!( + matches!(res.payload[0], ResponseData::Error(_)), + "Unexpected response: {:?}", + res.payload[0] + ); + + // Also verify that we didn't actually create the file + assert!(!path.exists(), "File created unexpectedly"); + } + + #[tokio::test] + async fn file_append_should_send_ok_when_successful() { + let (conn_id, state, tx, mut rx) = setup(1); + + // Create a temporary file and fill it with some contents + let mut file = NamedTempFile::new().unwrap(); + file.write_all(b"some file contents").unwrap(); + + let req = Request::new( + "test-tenant", + vec![RequestData::FileAppend { + path: file.path().to_path_buf(), + data: b"some extra contents".to_vec(), + }], + ); + + process(conn_id, state, req, tx).await.unwrap(); + + let res = rx.recv().await.unwrap(); + assert_eq!(res.payload.len(), 1, "Wrong payload size"); + assert!( + matches!(res.payload[0], ResponseData::Ok), + "Unexpected response: {:?}", + res.payload[0] + ); + + // Also verify that we actually did append to the file + assert_eq!( + tokio::fs::read_to_string(file.path()).await.unwrap(), + "some file contentssome extra contents" + ); + } + + #[tokio::test] + async fn file_append_text_should_send_error_if_fails_to_create_file() { + let (conn_id, state, tx, mut rx) = setup(1); + + // Create a temporary path and add to it to ensure that there are + // extra components that don't exist to cause writing to fail + let path = temppath().join("some_file"); + + let req = Request::new( + "test-tenant", + vec![RequestData::FileAppendText { + path: path.to_path_buf(), + text: String::from("some extra contents"), + }], + ); + + process(conn_id, state, req, tx).await.unwrap(); + + let res = rx.recv().await.unwrap(); + assert_eq!(res.payload.len(), 1, "Wrong payload size"); + assert!( + matches!(res.payload[0], ResponseData::Error(_)), + "Unexpected response: {:?}", + res.payload[0] + ); + + // Also verify that we didn't actually create the file + assert!(!path.exists(), "File created unexpectedly"); + } + + #[tokio::test] + async fn file_append_text_should_send_ok_when_successful() { + let (conn_id, state, tx, mut rx) = setup(1); + + // Create a temporary file and fill it with some contents + let mut file = NamedTempFile::new().unwrap(); + file.write_all(b"some file contents").unwrap(); + + let req = Request::new( + "test-tenant", + vec![RequestData::FileAppendText { + path: file.path().to_path_buf(), + text: String::from("some extra contents"), + }], 
+ ); + + process(conn_id, state, req, tx).await.unwrap(); + + let res = rx.recv().await.unwrap(); + assert_eq!(res.payload.len(), 1, "Wrong payload size"); + assert!( + matches!(res.payload[0], ResponseData::Ok), + "Unexpected response: {:?}", + res.payload[0] + ); + + // Also verify that we actually did append to the file + assert_eq!( + tokio::fs::read_to_string(file.path()).await.unwrap(), + "some file contentssome extra contents" + ); + } + + #[tokio::test] + async fn dir_read_should_send_error_if_directory_does_not_exist() { + let (conn_id, state, tx, mut rx) = setup(1); + let path = temppath(); + + let req = Request::new( + "test-tenant", + vec![RequestData::DirRead { + path, + depth: 0, + absolute: false, + canonicalize: false, + include_root: false, + }], + ); + + process(conn_id, state, req, tx).await.unwrap(); + + let res = rx.recv().await.unwrap(); + assert_eq!(res.payload.len(), 1, "Wrong payload size"); + assert!( + matches!(res.payload[0], ResponseData::Error(_)), + "Unexpected response: {:?}", + res.payload[0] + ); + } + + // /root/ + // /root/file1 + // /root/sub1/ + // /root/sub1/file2 + async fn setup_dir() -> TempDir { + let root_dir = TempDir::new().unwrap(); + let file1 = root_dir.path().join("file1"); + let sub1 = root_dir.path().join("sub1"); + let file2 = sub1.join("file2"); + + tokio::fs::write(file1.as_path(), "").await.unwrap(); + tokio::fs::create_dir(sub1.as_path()).await.unwrap(); + tokio::fs::write(file2.as_path(), "").await.unwrap(); + + root_dir + } + + #[tokio::test] + async fn dir_read_should_support_depth_limits() { + let (conn_id, state, tx, mut rx) = setup(1); + + // Create directory with some nested items + let root_dir = setup_dir().await; + + let req = Request::new( + "test-tenant", + vec![RequestData::DirRead { + path: root_dir.path().to_path_buf(), + depth: 1, + absolute: false, + canonicalize: false, + include_root: false, + }], + ); + + process(conn_id, state, req, tx).await.unwrap(); + + let res = rx.recv().await.unwrap(); + assert_eq!(res.payload.len(), 1, "Wrong payload size"); + match &res.payload[0] { + ResponseData::DirEntries { entries, .. } => { + assert_eq!(entries.len(), 2, "Wrong number of entries found"); + + assert_eq!(entries[0].file_type, FileType::File); + assert_eq!(entries[0].path, Path::new("file1")); + assert_eq!(entries[0].depth, 1); + + assert_eq!(entries[1].file_type, FileType::Dir); + assert_eq!(entries[1].path, Path::new("sub1")); + assert_eq!(entries[1].depth, 1); + } + x => panic!("Unexpected response: {:?}", x), + } + } + + #[tokio::test] + async fn dir_read_should_support_unlimited_depth_using_zero() { + let (conn_id, state, tx, mut rx) = setup(1); + + // Create directory with some nested items + let root_dir = setup_dir().await; + + let req = Request::new( + "test-tenant", + vec![RequestData::DirRead { + path: root_dir.path().to_path_buf(), + depth: 0, + absolute: false, + canonicalize: false, + include_root: false, + }], + ); + + process(conn_id, state, req, tx).await.unwrap(); + + let res = rx.recv().await.unwrap(); + assert_eq!(res.payload.len(), 1, "Wrong payload size"); + match &res.payload[0] { + ResponseData::DirEntries { entries, .. 
} => { + assert_eq!(entries.len(), 3, "Wrong number of entries found"); + + assert_eq!(entries[0].file_type, FileType::File); + assert_eq!(entries[0].path, Path::new("file1")); + assert_eq!(entries[0].depth, 1); + + assert_eq!(entries[1].file_type, FileType::Dir); + assert_eq!(entries[1].path, Path::new("sub1")); + assert_eq!(entries[1].depth, 1); + + assert_eq!(entries[2].file_type, FileType::File); + assert_eq!(entries[2].path, Path::new("sub1").join("file2")); + assert_eq!(entries[2].depth, 2); + } + x => panic!("Unexpected response: {:?}", x), + } + } + + #[tokio::test] + async fn dir_read_should_support_including_directory_in_returned_entries() { + let (conn_id, state, tx, mut rx) = setup(1); + + // Create directory with some nested items + let root_dir = setup_dir().await; + + let req = Request::new( + "test-tenant", + vec![RequestData::DirRead { + path: root_dir.path().to_path_buf(), + depth: 1, + absolute: false, + canonicalize: false, + include_root: true, + }], + ); + + process(conn_id, state, req, tx).await.unwrap(); + + let res = rx.recv().await.unwrap(); + assert_eq!(res.payload.len(), 1, "Wrong payload size"); + match &res.payload[0] { + ResponseData::DirEntries { entries, .. } => { + assert_eq!(entries.len(), 3, "Wrong number of entries found"); + + // NOTE: Root entry is always absolute, resolved path + assert_eq!(entries[0].file_type, FileType::Dir); + assert_eq!(entries[0].path, root_dir.path().canonicalize().unwrap()); + assert_eq!(entries[0].depth, 0); + + assert_eq!(entries[1].file_type, FileType::File); + assert_eq!(entries[1].path, Path::new("file1")); + assert_eq!(entries[1].depth, 1); + + assert_eq!(entries[2].file_type, FileType::Dir); + assert_eq!(entries[2].path, Path::new("sub1")); + assert_eq!(entries[2].depth, 1); + } + x => panic!("Unexpected response: {:?}", x), + } + } + + #[tokio::test] + async fn dir_read_should_support_returning_absolute_paths() { + let (conn_id, state, tx, mut rx) = setup(1); + + // Create directory with some nested items + let root_dir = setup_dir().await; + + let req = Request::new( + "test-tenant", + vec![RequestData::DirRead { + path: root_dir.path().to_path_buf(), + depth: 1, + absolute: true, + canonicalize: false, + include_root: false, + }], + ); + + process(conn_id, state, req, tx).await.unwrap(); + + let res = rx.recv().await.unwrap(); + assert_eq!(res.payload.len(), 1, "Wrong payload size"); + match &res.payload[0] { + ResponseData::DirEntries { entries, .. 
} => { + assert_eq!(entries.len(), 2, "Wrong number of entries found"); + let root_path = root_dir.path().canonicalize().unwrap(); + + assert_eq!(entries[0].file_type, FileType::File); + assert_eq!(entries[0].path, root_path.join("file1")); + assert_eq!(entries[0].depth, 1); + + assert_eq!(entries[1].file_type, FileType::Dir); + assert_eq!(entries[1].path, root_path.join("sub1")); + assert_eq!(entries[1].depth, 1); + } + x => panic!("Unexpected response: {:?}", x), + } + } + + #[tokio::test] + #[ignore] + async fn dir_read_should_support_returning_canonicalized_paths() { + todo!("Figure out best way to support symlink tests"); + } + + #[tokio::test] + async fn dir_create_should_send_error_if_fails() { + let (conn_id, state, tx, mut rx) = setup(1); + + // Make a path that has multiple non-existent components + // so the creation will fail + let root_dir = setup_dir().await; + let path = root_dir.path().join("nested").join("new-dir"); + + let req = Request::new( + "test-tenant", + vec![RequestData::DirCreate { + path: path.to_path_buf(), + all: false, + }], + ); + + process(conn_id, state, req, tx).await.unwrap(); + + let res = rx.recv().await.unwrap(); + assert_eq!(res.payload.len(), 1, "Wrong payload size"); + assert!( + matches!(res.payload[0], ResponseData::Error(_)), + "Unexpected response: {:?}", + res.payload[0] + ); + + // Also verify that the directory was not actually created + assert!(!path.exists(), "Path unexpectedly exists"); + } + + #[tokio::test] + async fn dir_create_should_send_ok_when_successful() { + let (conn_id, state, tx, mut rx) = setup(1); + let root_dir = setup_dir().await; + let path = root_dir.path().join("new-dir"); + + let req = Request::new( + "test-tenant", + vec![RequestData::DirCreate { + path: path.to_path_buf(), + all: false, + }], + ); + + process(conn_id, state, req, tx).await.unwrap(); + + let res = rx.recv().await.unwrap(); + assert_eq!(res.payload.len(), 1, "Wrong payload size"); + assert!( + matches!(res.payload[0], ResponseData::Ok), + "Unexpected response: {:?}", + res.payload[0] + ); + + // Also verify that the directory was actually created + assert!(path.exists(), "Directory not created"); + } + + #[tokio::test] + async fn dir_create_should_support_creating_multiple_dir_components() { + let (conn_id, state, tx, mut rx) = setup(1); + let root_dir = setup_dir().await; + let path = root_dir.path().join("nested").join("new-dir"); + + let req = Request::new( + "test-tenant", + vec![RequestData::DirCreate { + path: path.to_path_buf(), + all: true, + }], + ); + + process(conn_id, state, req, tx).await.unwrap(); + + let res = rx.recv().await.unwrap(); + assert_eq!(res.payload.len(), 1, "Wrong payload size"); + assert!( + matches!(res.payload[0], ResponseData::Ok), + "Unexpected response: {:?}", + res.payload[0] + ); + + // Also verify that the directory was actually created + assert!(path.exists(), "Directory not created"); + } +} diff --git a/src/core/server/distant/mod.rs b/src/core/server/distant/mod.rs index f5ca4ec..b4d8075 100644 --- a/src/core/server/distant/mod.rs +++ b/src/core/server/distant/mod.rs @@ -1,39 +1,37 @@ mod handler; -mod port; mod state; -mod utils; -pub use port::{PortRange, PortRangeParseError}; use state::State; use crate::core::{ data::{Request, Response}, net::{SecretKey, Transport, TransportReadHalf, TransportWriteHalf}, + server::{ + utils::{ConnTracker, ShutdownTask}, + PortRange, + }, }; +use futures::future::OptionFuture; use log::*; -use std::{ - net::{IpAddr, SocketAddr}, - sync::Arc, -}; +use std::{net::IpAddr, 
sync::Arc}; use tokio::{ io, net::{tcp, TcpListener, TcpStream}, - runtime::Handle, - sync::{mpsc, Mutex, Notify}, + sync::{mpsc, Mutex}, task::{JoinError, JoinHandle}, time::Duration, }; /// Represents a server that listens for requests, processes them, and sends responses -pub struct Server { +pub struct DistantServer { port: u16, - state: Arc<Mutex<State<SocketAddr>>>, auth_key: Arc<SecretKey>, - notify: Arc<Notify>, conn_task: JoinHandle<()>, } -impl Server { +impl DistantServer { + /// Bind to an IP address and port from the given range, taking an optional shutdown duration + /// that will shut down the server if there is no active connection after that duration pub async fn bind( addr: IpAddr, port: PortRange, @@ -47,21 +45,19 @@ impl Server { debug!("Bound to port: {}", port); // Build our state for the server - let state: Arc<Mutex<State<SocketAddr>>> = Arc::new(Mutex::new(State::default())); + let state: Arc<Mutex<State>> = Arc::new(Mutex::new(State::default())); let auth_key = Arc::new(SecretKey::default()); - let (ct, notify) = utils::new_shutdown_task(Handle::current(), shutdown_after); + let (shutdown, tracker) = ShutdownTask::maybe_initialize(shutdown_after); // Spawn our connection task - let state_2 = Arc::clone(&state); let auth_key_2 = Arc::clone(&auth_key); - let notify_2 = Arc::clone(&notify); let conn_task = tokio::spawn(async move { connection_loop( listener, - state_2, + state, auth_key_2, - ct, - notify_2, + tracker, + shutdown, max_msg_capacity, ) .await }); @@ -69,9 +65,7 @@ Ok(Self { port, - state, auth_key, - notify, conn_task, }) } @@ -90,33 +84,32 @@ impl Server { pub async fn wait(self) -> Result<(), JoinError> { self.conn_task.await } - - /// Shutdown the server - pub fn shutdown(&self) { - self.notify.notify_one() - } } async fn connection_loop( listener: TcpListener, - state: Arc<Mutex<State<SocketAddr>>>, + state: Arc<Mutex<State>>, auth_key: Arc<SecretKey>, - tracker: Arc<Mutex<ConnTracker>>, - notify: Arc<Notify>, + tracker: Option<Arc<Mutex<ConnTracker>>>, + shutdown: OptionFuture<ShutdownTask>, max_msg_capacity: usize, ) { - loop { - tokio::select! { - result = listener.accept() => {match result { + let inner = async move { + loop { + match listener.accept().await { Ok((conn, addr)) => { + let conn_id = rand::random(); + debug!("<Conn @ {}> Established against {}", conn_id, addr); if let Err(x) = on_new_conn( conn, - addr, + conn_id, Arc::clone(&state), Arc::clone(&auth_key), - Arc::clone(&tracker), - max_msg_capacity + tracker.as_ref().map(Arc::clone), + max_msg_capacity, + ) + .await + { error!("<Conn @ {}> Failed handshake: {}", addr, x); } } @@ -124,12 +117,15 @@ async fn connection_loop( error!("Listener failed: {}", x); break; } - }} - _ = notify.notified() => { - warn!("Reached shutdown timeout, so terminating"); - break; } } + }; + + tokio::select!
{ + _ = inner => {} + _ = shutdown => { + warn!("Reached shutdown timeout, so terminating"); + } } } @@ -137,10 +133,10 @@ async fn connection_loop( /// input and output, returning join handles for the input and output tasks respectively async fn on_new_conn( conn: TcpStream, - addr: SocketAddr, - state: Arc<Mutex<State<SocketAddr>>>, + conn_id: usize, + state: Arc<Mutex<State>>, auth_key: Arc<SecretKey>, - tracker: Arc<Mutex<ConnTracker>>, + tracker: Option<Arc<Mutex<ConnTracker>>>, max_msg_capacity: usize, ) -> io::Result<(JoinHandle<()>, JoinHandle<()>)> { // Establish a proper connection via a handshake, @@ -151,23 +147,26 @@ async fn on_new_conn( // and output concurrently let (t_read, t_write) = transport.into_split(); let (tx, rx) = mpsc::channel(max_msg_capacity); - let ct_2 = Arc::clone(&tracker); // Spawn a new task that loops to handle requests from the client let req_task = tokio::spawn({ - let f = request_loop(addr, Arc::clone(&state), t_read, tx); + let f = request_loop(conn_id, Arc::clone(&state), t_read, tx); let state = Arc::clone(&state); async move { - ct_2.lock().await.increment(); + if let Some(ct) = tracker.as_ref() { + ct.lock().await.increment(); + } f.await; - state.lock().await.cleanup_client(addr).await; - ct_2.lock().await.decrement(); + state.lock().await.cleanup_connection(conn_id).await; + if let Some(ct) = tracker.as_ref() { + ct.lock().await.decrement(); + } } }); // Spawn a new task that loops to handle responses to the client - let res_task = tokio::spawn(async move { response_loop(addr, t_write, rx).await }); + let res_task = tokio::spawn(async move { response_loop(conn_id, t_write, rx).await }); Ok((req_task, res_task)) } @@ -175,8 +174,8 @@ async fn on_new_conn( /// Repeatedly reads in new requests, processes them, and sends their responses to the /// response loop async fn request_loop( - addr: SocketAddr, - state: Arc<Mutex<State<SocketAddr>>>, + conn_id: usize, + state: Arc<Mutex<State>>, mut transport: TransportReadHalf<tcp::OwnedReadHalf>, tx: mpsc::Sender<Response>, ) { @@ -185,22 +184,23 @@ async fn request_loop( Ok(Some(req)) => { debug!( "<Conn @ {}> Received request of type{} {}", - addr, + conn_id, if req.payload.len() > 1 { "s" } else { "" }, req.to_payload_type_string() ); - if let Err(x) = handler::process(addr, Arc::clone(&state), req, tx.clone()).await { - error!("<Conn @ {}> {}", addr, x); + if let Err(x) = handler::process(conn_id, Arc::clone(&state), req, tx.clone()).await + { + error!("<Conn @ {}> {}", conn_id, x); break; } } Ok(None) => { - info!("<Conn @ {}> Closed connection", addr); + info!("<Conn @ {}> Closed connection", conn_id); break; } Err(x) => { - error!("<Conn @ {}> {}", addr, x); + error!("<Conn @ {}> {}", conn_id, x); break; } } @@ -209,13 +209,13 @@ /// Repeatedly sends responses out over the wire async fn response_loop( - addr: SocketAddr, + conn_id: usize, mut transport: TransportWriteHalf<tcp::OwnedWriteHalf>, mut rx: mpsc::Receiver<Response>, ) { while let Some(res) = rx.recv().await { if let Err(x) = transport.send(res).await { - error!("<Conn @ {}> {}", addr, x); + error!("<Conn @ {}> {}", conn_id, x); break; } } diff --git a/src/core/server/distant/port.rs b/src/core/server/distant/port.rs deleted file mode 100644 index 58272be..0000000 --- a/src/core/server/distant/port.rs +++ /dev/null @@ -1,66 +0,0 @@ -use derive_more::{Display, Error}; -use std::{ - net::{IpAddr, SocketAddr}, - str::FromStr, -}; - -/// Represents some range of ports -#[derive(Clone, Debug, Display, PartialEq, Eq)] -#[display( - fmt = "{}{}", - start, - "end.as_ref().map(|end| format!(\"[:{}]\", end)).unwrap_or_default()" -)] -pub struct PortRange { - pub start: u16, - pub end: Option<u16>, -} - -impl PortRange { - /// Builds a collection of `SocketAddr` instances from the port range and given ip address
- pub fn make_socket_addrs(&self, addr: impl Into<IpAddr>) -> Vec<SocketAddr> { - let mut socket_addrs = Vec::new(); - let addr = addr.into(); - - for port in self.start..=self.end.unwrap_or(self.start) { - socket_addrs.push(SocketAddr::from((addr, port))); - } - - socket_addrs - } -} - -#[derive(Copy, Clone, Debug, Display, Error, PartialEq, Eq)] -pub enum PortRangeParseError { - InvalidPort, - MissingPort, -} - -impl FromStr for PortRange { - type Err = PortRangeParseError; - - /// Parses PORT into single range or PORT1:PORTN into full range - fn from_str(s: &str) -> Result<Self, Self::Err> { - let mut tokens = s.trim().split(':'); - let start = tokens - .next() - .ok_or(PortRangeParseError::MissingPort)? - .parse::<u16>() - .map_err(|_| PortRangeParseError::InvalidPort)?; - let end = if let Some(token) = tokens.next() { - Some( - token - .parse::<u16>() - .map_err(|_| PortRangeParseError::InvalidPort)?, - ) - } else { - None - }; - - if tokens.next().is_some() { - return Err(PortRangeParseError::InvalidPort); - } - - Ok(Self { start, end }) - } -} diff --git a/src/core/server/distant/state.rs b/src/core/server/distant/state.rs index 5f22193..43ba9ec 100644 --- a/src/core/server/distant/state.rs +++ b/src/core/server/distant/state.rs @@ -1,46 +1,41 @@ use log::*; -use std::{collections::HashMap, fmt::Debug, hash::Hash}; +use std::collections::HashMap; use tokio::sync::{mpsc, oneshot}; /// Holds state related to multiple clients managed by a server -pub struct State<ClientId> -where - ClientId: Debug + Hash + PartialEq + Eq, -{ +#[derive(Default)] +pub struct State { /// Map of all processes running on the server pub processes: HashMap<usize, Process>, /// List of processes that will be killed when a client drops - client_processes: HashMap<ClientId, Vec<usize>>, + client_processes: HashMap<usize, Vec<usize>>, } -impl<ClientId> State<ClientId> -where - ClientId: Debug + Hash + PartialEq + Eq, -{ +impl State { /// Pushes a new process associated with a client - pub fn push_process(&mut self, client_id: ClientId, process: Process) { + pub fn push_process(&mut self, conn_id: usize, process: Process) { self.client_processes - .entry(client_id) + .entry(conn_id) .or_insert(Vec::new()) .push(process.id); self.processes.insert(process.id, process); } - /// Cleans up state associated with a particular client - pub async fn cleanup_client(&mut self, client_id: ClientId) { - debug!("<Client @ {}> Cleaning up state", client_id); - if let Some(ids) = self.client_processes.remove(&client_id) { + /// Cleans up state associated with a particular connection + pub async fn cleanup_connection(&mut self, conn_id: usize) { + debug!("<Conn @ {}> Cleaning up state", conn_id); + if let Some(ids) = self.client_processes.remove(&conn_id) { for id in ids { if let Some(process) = self.processes.remove(&id) { trace!( - "<Client @ {}> Requesting proc {} be killed", - client_id, + "<Conn @ {}> Requesting proc {} be killed", + conn_id, process.id ); if let Err(_) = process.kill_tx.send(()) { error!( - "Client {} failed to send process {} kill signal", + "Conn {} failed to send process {} kill signal", id, process.id ); } @@ -50,18 +45,6 @@ where } } -impl<ClientId> Default for State<ClientId> -where - ClientId: Debug + Hash + PartialEq + Eq, -{ - fn default() -> Self { - Self { - processes: HashMap::new(), - client_processes: HashMap::new(), - } - } -} - /// Represents an actively-running process pub struct Process { /// Id of the process diff --git a/src/core/server/distant/utils.rs b/src/core/server/distant/utils.rs deleted file mode 100644 index 9808d0c..0000000 --- a/src/core/server/distant/utils.rs +++ /dev/null @@ -1,101 +0,0 @@ -use log::*; -use std::{sync::Arc, time::Duration}; -use tokio::{
runtime::Handle, - sync::{Mutex, Notify}, - time::{self, Instant}, -}; - -pub struct ConnTracker { - time: Instant, - cnt: usize, -} - -impl ConnTracker { - pub fn new() -> Self { - Self { - time: Instant::now(), - cnt: 0, - } - } - - pub fn increment(&mut self) { - self.time = Instant::now(); - self.cnt += 1; - } - - pub fn decrement(&mut self) { - if self.cnt > 0 { - self.time = Instant::now(); - self.cnt -= 1; - } - } - - pub fn time_and_cnt(&self) -> (Instant, usize) { - (self.time, self.cnt) - } - - pub fn has_exceeded_timeout(&self, duration: Duration) -> bool { - self.cnt == 0 && self.time.elapsed() > duration - } -} - -/// Spawns a new task that continues to monitor the time since a -/// connection on the server existed, shutting down the runtime -/// if the time is exceeded -pub fn new_shutdown_task( - handle: Handle, - duration: Option<Duration>, -) -> (Arc<Mutex<ConnTracker>>, Arc<Notify>) { - let ct = Arc::new(Mutex::new(ConnTracker::new())); - let notify = Arc::new(Notify::new()); - - let ct_2 = Arc::clone(&ct); - let notify_2 = Arc::clone(&notify); - if let Some(duration) = duration { - handle.spawn(async move { - loop { - // Get the time since the last connection joined/left - let (base_time, cnt) = ct_2.lock().await.time_and_cnt(); - - // If we have no connections left, we want to wait - // until the remaining period has passed and then - // verify that we still have no connections - if cnt == 0 { - // Get the time we should wait based on when the last connection - // was dropped; this closes the gap in the case where we start - // sometime later than exactly duration since the last check - let next_time = base_time + duration; - let wait_duration = next_time - .checked_duration_since(Instant::now()) - .unwrap_or_default() - + Duration::from_millis(1); - - // Wait until we've reached our desired duration since the - // last connection was dropped - time::sleep(wait_duration).await; - - // If we do have a connection at this point, don't exit - if !ct_2.lock().await.has_exceeded_timeout(duration) { - continue; - } - - // Otherwise, we now should exit, which we do by reporting - debug!( - "Shutdown time of {}s has been reached!", - duration.as_secs_f32() - ); - notify_2.notify_one(); - break; - } - - // Otherwise, we just wait the full duration as worst case - // we'll have waited just about the time desired if right - // after waiting starts the last connection is closed - time::sleep(duration).await; - } - }); - } - - (ct, notify) -} diff --git a/src/core/server/mod.rs b/src/core/server/mod.rs index f65409b..60e40ca 100644 --- a/src/core/server/mod.rs +++ b/src/core/server/mod.rs @@ -1,2 +1,8 @@ mod distant; -pub use distant::{PortRange, PortRangeParseError, Server as DistantServer}; +mod port; +mod relay; +mod utils; + +pub use self::distant::DistantServer; +pub use port::PortRange; +pub use relay::RelayServer;
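As a quick orientation before the new `port.rs` that follows, here is a small illustrative sketch (not part of the patch itself) of how the relocated `PortRange` type is meant to be used. It relies only on the `FromStr`, `IntoIterator`, and `make_socket_addrs` items visible in this diff, and assumes the `distant::PortRange` re-export from `server/mod.rs` above:

```rust
use distant::PortRange;
use std::net::IpAddr;

fn main() {
    // "8000" parses as a single port; "8000:8002" as an inclusive range
    let range: PortRange = "8000:8002".parse().unwrap();
    assert_eq!(range.start, 8000);
    assert_eq!(range.end, Some(8002));

    // Iterating a &PortRange yields every port in the inclusive range
    let ports: Vec<u16> = (&range).into_iter().collect();
    assert_eq!(ports, vec![8000, 8001, 8002]);

    // make_socket_addrs pairs each port with the given IP for binding
    let ip: IpAddr = "127.0.0.1".parse().unwrap();
    assert_eq!(range.make_socket_addrs(ip).len(), 3);
}
```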
diff --git a/src/core/server/port.rs b/src/core/server/port.rs new file mode 100644 index 0000000..50600bd --- /dev/null +++ b/src/core/server/port.rs @@ -0,0 +1,180 @@ +use derive_more::Display; +use std::{ + net::{IpAddr, SocketAddr}, + ops::RangeInclusive, + str::FromStr, +}; + +/// Represents some range of ports +#[derive(Clone, Debug, Display, PartialEq, Eq)] +#[display( + fmt = "{}{}", + start, + "end.as_ref().map(|end| format!(\":{}\", end)).unwrap_or_default()" +)] +pub struct PortRange { + pub start: u16, + pub end: Option<u16>, +} + +impl PortRange { + /// Builds a collection of `SocketAddr` instances from the port range and given ip address + pub fn make_socket_addrs(&self, addr: impl Into<IpAddr>) -> Vec<SocketAddr> { + let mut socket_addrs = Vec::new(); + let addr = addr.into(); + + for port in self { + socket_addrs.push(SocketAddr::from((addr, port))); + } + + socket_addrs + } +} + +impl From<RangeInclusive<u16>> for PortRange { + fn from(r: RangeInclusive<u16>) -> Self { + let (start, end) = r.into_inner(); + Self { + start, + end: Some(end), + } + } +} + +impl<'a> IntoIterator for &'a PortRange { + type Item = u16; + type IntoIter = RangeInclusive<u16>; + + fn into_iter(self) -> Self::IntoIter { + self.start..=self.end.unwrap_or(self.start) + } +} + +impl IntoIterator for PortRange { + type Item = u16; + type IntoIter = RangeInclusive<u16>; + + fn into_iter(self) -> Self::IntoIter { + self.start..=self.end.unwrap_or(self.start) + } +} + +impl FromStr for PortRange { + type Err = std::num::ParseIntError; + + /// Parses PORT into single range or PORT1:PORTN into full range + fn from_str(s: &str) -> Result<Self, Self::Err> { + match s.split_once(':') { + Some((start, end)) => Ok(Self { + start: start.parse()?, + end: Some(end.parse()?), + }), + None => Ok(Self { + start: s.parse()?, + end: None, + }), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn display_should_properly_reflect_port_range() { + let p = PortRange { + start: 100, + end: None, + }; + assert_eq!(p.to_string(), "100"); + + let p = PortRange { + start: 100, + end: Some(200), + }; + assert_eq!(p.to_string(), "100:200"); + } + + #[test] + fn from_range_inclusive_should_map_to_port_range() { + let p = PortRange::from(100..=200); + assert_eq!(p.start, 100); + assert_eq!(p.end, Some(200)); + } + + #[test] + fn into_iterator_should_support_port_range() { + let p = PortRange { + start: 1, + end: None, + }; + assert_eq!((&p).into_iter().collect::<Vec<u16>>(), vec![1]); + assert_eq!(p.into_iter().collect::<Vec<u16>>(), vec![1]); + + let p = PortRange { + start: 1, + end: Some(3), + }; + assert_eq!((&p).into_iter().collect::<Vec<u16>>(), vec![1, 2, 3]); + assert_eq!(p.into_iter().collect::<Vec<u16>>(), vec![1, 2, 3]); + } + + #[test] + fn make_socket_addrs_should_produce_a_socket_addr_per_port() { + let ip_addr = "127.0.0.1".parse::<IpAddr>().unwrap(); + + let p = PortRange { + start: 1, + end: None, + }; + assert_eq!( + p.make_socket_addrs(ip_addr), + vec![SocketAddr::new(ip_addr, 1)] + ); + + let p = PortRange { + start: 1, + end: Some(3), + }; + assert_eq!( + p.make_socket_addrs(ip_addr), + vec![ + SocketAddr::new(ip_addr, 1), + SocketAddr::new(ip_addr, 2), + SocketAddr::new(ip_addr, 3), + ] + ); + } + + #[test] + fn parse_should_fail_if_not_starting_with_number() { + assert!("100a".parse::<PortRange>().is_err()); + } + + #[test] + fn parse_should_fail_if_provided_end_port_that_is_not_a_number() { + assert!("100:200a".parse::<PortRange>().is_err()); + } + + #[test] + fn parse_should_be_able_to_properly_read_in_port_range() { + let p: PortRange = "100".parse().unwrap(); + assert_eq!( + p, + PortRange { + start: 100, + end: None + } + ); + + let p: PortRange = "100:200".parse().unwrap(); + assert_eq!( + p, + PortRange { + start: 100, + end: Some(200) + } + ); + } +} diff --git a/src/core/server/relay.rs b/src/core/server/relay.rs new file mode 100644 index 0000000..cdaeb0d --- /dev/null +++ b/src/core/server/relay.rs @@ -0,0 +1,336 @@ +use crate::core::{ + client::Session, + constants::CLIENT_BROADCAST_CHANNEL_CAPACITY, + data::{Request, RequestData, Response, ResponseData}, + net::{DataStream, Listener, Transport, TransportReadHalf, TransportWriteHalf}, + server::utils::{ConnTracker, ShutdownTask}, +}; +use log::*; +use std::{collections::HashMap, marker::Unpin, sync::Arc}; +use tokio::{ + io::{self, AsyncRead, AsyncWrite},
sync::{broadcast, mpsc, oneshot, Mutex}, + task::{JoinError, JoinHandle}, + time::Duration, +}; + +/// Represents a server that relays requests & responses between connections and the +/// actual server +pub struct RelayServer { + forward_task: JoinHandle<()>, + accept_task: JoinHandle<()>, + conns: Arc<Mutex<HashMap<usize, Conn>>>, +} + +impl RelayServer { + pub async fn initialize<T1, T2, L>( + mut session: Session<T1>, + listener: L, + shutdown_after: Option<Duration>, + ) -> io::Result<Self> + where + T1: DataStream + 'static, + T2: DataStream + Send + 'static, + L: Listener<Conn = T2> + 'static, + { + // Get a copy of our session's broadcaster so we can have each connection + // subscribe to it for new messages filtered by tenant + debug!("Acquiring session broadcaster"); + let broadcaster = session.to_response_broadcaster(); + + // Spawn a task that forwards requests from connections to the server + debug!("Spawning request forwarding task"); + let (req_tx, mut req_rx) = mpsc::channel::<Request>(CLIENT_BROADCAST_CHANNEL_CAPACITY); + let forward_task = tokio::spawn(async move { + while let Some(req) = req_rx.recv().await { + debug!( + "Forwarding request of type{} {} to server", + if req.payload.len() > 1 { "s" } else { "" }, + req.to_payload_type_string() + ); + if let Err(x) = session.fire(req).await { + error!("Session failed to send request: {:?}", x); + break; + } + } + }); + + let (shutdown, tracker) = ShutdownTask::maybe_initialize(shutdown_after); + let conns = Arc::new(Mutex::new(HashMap::new())); + let conns_2 = Arc::clone(&conns); + let accept_task = tokio::spawn(async move { + let inner = async move { + loop { + match listener.accept().await { + Ok(stream) => { + let result = Conn::initialize( + stream, + req_tx.clone(), + broadcaster.clone(), + tracker.as_ref().map(Arc::clone), + ) + .await; + + match result { + Ok(conn) => conns_2.lock().await.insert(conn.id(), conn), + Err(x) => { + error!("Failed to initialize connection: {}", x); + continue; + } + }; + } + Err(x) => { + debug!("Listener has closed: {}", x); + break; + } + } + } + }; + + tokio::select!
{ + _ = inner => {} + _ = shutdown => { + warn!("Reached shutdown timeout, so terminating"); + } } + }); + + Ok(Self { + forward_task, + accept_task, + conns, + }) + } + + pub async fn wait(self) -> Result<(), JoinError> { + match tokio::try_join!(self.forward_task, self.accept_task) { + Ok(_) => Ok(()), + Err(x) => Err(x), + } + } + + pub async fn abort(&self) { + self.forward_task.abort(); + self.accept_task.abort(); + self.conns + .lock() + .await + .values() + .for_each(|conn| conn.abort()); + } +} + +struct Conn { + id: usize, + req_task: JoinHandle<()>, + res_task: JoinHandle<()>, +} + +/// Represents state associated with a connection +#[derive(Default)] +struct ConnState { + processes: Vec<usize>, +} + +impl Conn { + pub async fn initialize<T>( + stream: T, + req_tx: mpsc::Sender<Request>, + res_broadcaster: broadcast::Sender<Response>, + ct: Option<Arc<Mutex<ConnTracker>>>, + ) -> io::Result<Self> + where + T: DataStream + 'static, + { + // Create a unique id to associate with the connection since its address + // is not guaranteed to have an identifiable string + let id: usize = rand::random(); + + // Establish a proper connection via a handshake, discarding the connection otherwise + let transport = Transport::from_handshake(stream, None).await.map_err(|x| { + error!("<Conn @ {}> Failed handshake: {}", id, x); + io::Error::new(io::ErrorKind::Other, x) + })?; + let (t_read, t_write) = transport.into_split(); + + // Used to alert our response task of the connection's tenant name + // based on the first request received + let (tenant_tx, tenant_rx) = oneshot::channel(); + + // Create a state we use to keep track of connection-specific data + debug!("<Conn @ {}> Initializing internal state", id); + let state = Arc::new(Mutex::new(ConnState::default())); + + // Spawn task to continually receive responses from the session that + // may or may not be relevant to the connection, which will filter + // by tenant and then forward along any response that matches + let res_rx = res_broadcaster.subscribe(); + let state_2 = Arc::clone(&state); + let res_task = tokio::spawn(async move { + handle_conn_outgoing(id, state_2, t_write, tenant_rx, res_rx).await; + }); + + // Spawn task to continually read requests from connection and forward + // them along to be sent via the session + let req_tx = req_tx.clone(); + let req_task = tokio::spawn(async move { + if let Some(ct) = ct.as_ref() { + ct.lock().await.increment(); + } + handle_conn_incoming(id, state, t_read, tenant_tx, req_tx).await; + if let Some(ct) = ct.as_ref() { + ct.lock().await.decrement(); + } + debug!("<Conn @ {}> Disconnected", id); + }); + + Ok(Self { + id, + req_task, + res_task, + }) + } + + /// Id associated with the connection + pub fn id(&self) -> usize { + self.id + } + + /// Aborts the connection from the server side + pub fn abort(&self) { + self.req_task.abort(); + self.res_task.abort(); + } +} + +/// Conn::Request -> Session::Fire +async fn handle_conn_incoming<T>( + conn_id: usize, + state: Arc<Mutex<ConnState>>, + mut reader: TransportReadHalf<T>, + tenant_tx: oneshot::Sender<String>, + req_tx: mpsc::Sender<Request>, +) where + T: AsyncRead + Unpin, +{ + macro_rules!
process_req { + ($on_success:expr; $done:expr) => { + match reader.receive::().await { + Ok(Some(req)) => { + $on_success(&req); + if let Err(x) = req_tx.send(req).await { + error!( + "Failed to pass along request received on unix socket: {:?}", + x + ); + $done; + } + } + Ok(None) => $done, + Err(x) => { + error!("Failed to receive request from unix stream: {:?}", x); + $done; + } + } + }; + } + + let mut tenant = None; + + // NOTE: Have to acquire our first request outside our loop since the oneshot + // sender of the tenant's name is consuming + process_req!( + |req: &Request| { + tenant = Some(req.tenant.clone()); + if let Err(x) = tenant_tx.send(req.tenant.clone()) { + error!("Failed to send along acquired tenant name: {:?}", x); + return; + } + }; + return + ); + + // Loop and process all additional requests + loop { + process_req!(|_| {}; break); + } + + // At this point, we have processed at least one request successfully + // and should have the tenant populated. If we had a failure at the + // beginning, we exit the function early via return. + let tenant = tenant.unwrap(); + + // Perform cleanup if done by sending a request to kill each running process + // debug!("Cleaning conn {} :: killing process {}", conn_id, id); + if let Err(x) = req_tx + .send(Request::new( + tenant.clone(), + state + .lock() + .await + .processes + .iter() + .map(|id| RequestData::ProcKill { id: *id }) + .collect(), + )) + .await + { + error!(" Failed to send kill signals: {}", conn_id, x); + } +} + +async fn handle_conn_outgoing( + conn_id: usize, + state: Arc>, + mut writer: TransportWriteHalf, + tenant_rx: oneshot::Receiver, + mut res_rx: broadcast::Receiver, +) where + T: AsyncWrite + Unpin, +{ + // We wait for the tenant to be identified by the first request + // before processing responses to be sent back; this is easier + // to implement and yields the same result as we would be dropping + // all responses before we know the tenant + if let Ok(tenant) = tenant_rx.await { + debug!("Associated tenant {} with conn {}", tenant, conn_id); + loop { + match res_rx.recv().await { + // Forward along responses that are for our connection + Ok(res) if res.tenant == tenant => { + debug!( + "Conn {} being sent response of type{} {}", + conn_id, + if res.payload.len() > 1 { "s" } else { "" }, + res.to_payload_type_string(), + ); + + // If a new process was started, we want to capture the id and + // associate it with the connection + let ids = res.payload.iter().filter_map(|x| match x { + ResponseData::ProcStart { id } => Some(*id), + _ => None, + }); + for id in ids { + debug!("Tracking proc {} for conn {}", id, conn_id); + state.lock().await.processes.push(id); + } + + if let Err(x) = writer.send(res).await { + error!("Failed to send response through unix connection: {}", x); + break; + } + } + // Skip responses that are not for our connection + Ok(_) => {} + Err(x) => { + error!( + "Conn {} failed to receive broadcast response: {}", + conn_id, x + ); + break; + } + } + } + } +} diff --git a/src/core/server/utils.rs b/src/core/server/utils.rs new file mode 100644 index 0000000..c659433 --- /dev/null +++ b/src/core/server/utils.rs @@ -0,0 +1,291 @@ +use futures::future::OptionFuture; +use log::*; +use std::{ + future::Future, + pin::Pin, + sync::Arc, + task::{Context, Poll}, + time::Duration, +}; +use tokio::{ + sync::Mutex, + task::{JoinError, JoinHandle}, + time::{self, Instant}, +}; + +/// Task to keep track of a possible server shutdown based on connections +pub struct ShutdownTask { + task: 
JoinHandle<()>, + tracker: Arc<Mutex<ConnTracker>>, +} + +impl Future for ShutdownTask { + type Output = Result<(), JoinError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + Pin::new(&mut self.task).poll(cx) + } +} + +impl ShutdownTask { + /// Given an optional timeout, will either create the shutdown task or not, + /// returning an optional future for the completion of the shutdown task + /// alongside an optional connection tracker + pub fn maybe_initialize( + duration: Option<Duration>, + ) -> (OptionFuture<ShutdownTask>, Option<Arc<Mutex<ConnTracker>>>) { + match duration { + Some(duration) => { + let task = Self::initialize(duration); + let tracker = task.tracker(); + let task: OptionFuture<_> = Some(task).into(); + (task, Some(tracker)) + } + None => (None.into(), None), + } + } + + /// Spawns a new task that continues to monitor the time since a + /// connection on the server existed, reporting a shutdown to all listeners + /// once the timeout is exceeded + pub fn initialize(duration: Duration) -> Self { + let tracker = Arc::new(Mutex::new(ConnTracker::new())); + + let tracker_2 = Arc::clone(&tracker); + let task = tokio::spawn(async move { + loop { + // Get the time since the last connection joined/left + let (base_time, cnt) = tracker_2.lock().await.time_and_cnt(); + + // If we have no connections left, we want to wait + // until the remaining period has passed and then + // verify that we still have no connections + if cnt == 0 { + // Get the time we should wait based on when the last connection + // was dropped; this closes the gap in the case where we start + // sometime later than exactly duration since the last check + let next_time = base_time + duration; + let wait_duration = next_time + .checked_duration_since(Instant::now()) + .unwrap_or_default() + + Duration::from_millis(1); + + // Wait until we've reached our desired duration since the + // last connection was dropped + time::sleep(wait_duration).await; + + // If we do have a connection at this point, don't exit + if !tracker_2.lock().await.has_reached_timeout(duration) { + continue; + } + + // Otherwise, we now should exit, which we do by reporting + debug!( + "Shutdown time of {}s has been reached!", + duration.as_secs_f32() + ); + break; + } + + // Otherwise, we just wait the full duration as worst case + // we'll have waited just about the time desired if right + // after waiting starts the last connection is closed + time::sleep(duration).await; + } + }); + + Self { task, tracker } + } + + /// Produces a new copy of the connection tracker associated with the shutdown manager + pub fn tracker(&self) -> Arc<Mutex<ConnTracker>> { + Arc::clone(&self.tracker) + } +} + +pub struct ConnTracker { + time: Instant, + cnt: usize, +} + +impl ConnTracker { + pub fn new() -> Self { + Self { + time: Instant::now(), + cnt: 0, + } + } + + pub fn increment(&mut self) { + self.time = Instant::now(); + self.cnt += 1; + } + + pub fn decrement(&mut self) { + if self.cnt > 0 { + self.time = Instant::now(); + self.cnt -= 1; + } + } + + fn time_and_cnt(&self) -> (Instant, usize) { + (self.time, self.cnt) + } + + fn has_reached_timeout(&self, duration: Duration) -> bool { + self.cnt == 0 && self.time.elapsed() >= duration + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::thread; + + #[tokio::test] + async fn shutdown_task_should_not_resolve_if_has_connection_regardless_of_time() { + let mut task = ShutdownTask::initialize(Duration::from_millis(10)); + task.tracker().lock().await.increment(); + assert!( + futures::poll!(&mut task).is_pending(), + "Shutdown task unexpectedly completed"
+ ); + + time::sleep(Duration::from_millis(15)).await; + + assert!( + futures::poll!(task).is_pending(), + "Shutdown task unexpectedly completed" + ); + } + + #[tokio::test] + async fn shutdown_task_should_resolve_if_no_connection_for_minimum_duration() { + let mut task = ShutdownTask::initialize(Duration::from_millis(10)); + assert!( + futures::poll!(&mut task).is_pending(), + "Shutdown task unexpectedly completed" + ); + + time::sleep(Duration::from_millis(15)).await; + + assert!( + futures::poll!(task).is_ready(), + "Shutdown task unexpectedly pending" + ); + } + + #[tokio::test] + async fn shutdown_task_should_resolve_if_no_connection_for_minimum_duration_after_connection_removed( + ) { + let mut task = ShutdownTask::initialize(Duration::from_millis(10)); + task.tracker().lock().await.increment(); + assert!( + futures::poll!(&mut task).is_pending(), + "Shutdown task unexpectedly completed" + ); + + time::sleep(Duration::from_millis(15)).await; + assert!( + futures::poll!(&mut task).is_pending(), + "Shutdown task unexpectedly completed" + ); + + task.tracker().lock().await.decrement(); + time::sleep(Duration::from_millis(15)).await; + + assert!( + futures::poll!(task).is_ready(), + "Shutdown task unexpectedly pending" + ); + } + + #[tokio::test] + async fn shutdown_task_should_not_resolve_before_minimum_duration() { + let mut task = ShutdownTask::initialize(Duration::from_millis(10)); + assert!( + futures::poll!(&mut task).is_pending(), + "Shutdown task unexpectedly completed" + ); + + time::sleep(Duration::from_millis(5)).await; + + assert!( + futures::poll!(task).is_pending(), + "Shutdown task unexpectedly completed" + ); + } + + #[test] + fn conn_tracker_should_update_time_when_incremented() { + let mut tracker = ConnTracker::new(); + let (old_time, cnt) = tracker.time_and_cnt(); + assert_eq!(cnt, 0); + + // Wait to ensure that the new time will be different + thread::sleep(Duration::from_millis(1)); + + tracker.increment(); + let (new_time, cnt) = tracker.time_and_cnt(); + assert_eq!(cnt, 1); + assert!(new_time > old_time); + } + + #[test] + fn conn_tracker_should_update_time_when_decremented() { + let mut tracker = ConnTracker::new(); + tracker.increment(); + + let (old_time, cnt) = tracker.time_and_cnt(); + assert_eq!(cnt, 1); + + // Wait to ensure that the new time will be different + thread::sleep(Duration::from_millis(1)); + + tracker.decrement(); + let (new_time, cnt) = tracker.time_and_cnt(); + assert_eq!(cnt, 0); + assert!(new_time > old_time); + } + + #[test] + fn conn_tracker_should_not_update_time_when_decremented_if_at_zero_already() { + let mut tracker = ConnTracker::new(); + let (old_time, cnt) = tracker.time_and_cnt(); + assert_eq!(cnt, 0); + + // Wait to ensure that the new time would be different if updated + thread::sleep(Duration::from_millis(1)); + + tracker.decrement(); + let (new_time, cnt) = tracker.time_and_cnt(); + assert_eq!(cnt, 0); + assert!(new_time == old_time); + } + + #[test] + fn conn_tracker_should_report_timeout_reached_when_time_has_elapsed_and_no_connections() { + let tracker = ConnTracker::new(); + let (_, cnt) = tracker.time_and_cnt(); + assert_eq!(cnt, 0); + + // Wait to ensure that the new time would be different if updated + thread::sleep(Duration::from_millis(1)); + + assert!(tracker.has_reached_timeout(Duration::from_millis(1))); + } + + #[test] + fn conn_tracker_should_not_report_timeout_reached_when_time_has_elapsed_but_has_connections() { + let mut tracker = ConnTracker::new(); + tracker.increment(); + + let (_, cnt) = 
tracker.time_and_cnt(); + assert_eq!(cnt, 1); + + // Wait to ensure that the new time would be different if updated + thread::sleep(Duration::from_millis(1)); + + assert!(!tracker.has_reached_timeout(Duration::from_millis(1))); + } +} diff --git a/src/lib.rs b/src/lib.rs index e7374ac..29099b7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,7 @@ mod cli; mod core; -pub use self::core::{data, net}; +pub use self::core::{client::*, data, net, server::*}; use log::error; /// Main entrypoint into the program
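Tying the refactor together, here is a rough usage sketch of the reworked public API now re-exported from `lib.rs`. It is illustrative only: `DistantServer::bind`'s trailing `shutdown_after`/`max_msg_capacity` parameters and the `port()` getter are inferred from the hunks above (the relevant context lines are cut by the diff), so treat those as assumptions rather than confirmed signatures:

```rust
use distant::{DistantServer, PortRange};
use std::time::Duration;
use tokio::io;

#[tokio::main]
async fn main() -> io::Result<()> {
    let server = DistantServer::bind(
        // Bind to the first free loopback port in 8000..=8999
        "127.0.0.1".parse().unwrap(),
        PortRange::from(8000..=8999),
        // Shut down after 30s with no active connection (assumed parameter)
        Some(Duration::from_secs(30)),
        // Per-connection response channel capacity (assumed parameter)
        100,
    )
    .await?;

    println!("Listening on port {}", server.port()); // getter assumed

    // Resolves when the connection task ends, either because the listener
    // closed or the shutdown timeout was reached
    server
        .wait()
        .await
        .map_err(|x| io::Error::new(io::ErrorKind::Other, x))
}
```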