diff --git a/core/README.md b/core/README.md index 272b243..cdea14a 100644 --- a/core/README.md +++ b/core/README.md @@ -12,10 +12,10 @@ servers that operate on remote machines and clients that talk to them. - Asynchronous in nature, powered by [`tokio`](https://tokio.rs/) - Data is serialized to send across the wire via [`CBOR`](https://cbor.io/) -- Encryption & authentication are handled via [`orion`](https://crates.io/crates/orion) - - [XChaCha20Poly1305](https://cryptopp.com/wiki/XChaCha20Poly1305) for an authenticated encryption scheme - - [BLAKE2b-256](https://www.blake2.net/) in keyed mode for a second authentication - - [Elliptic Curve Diffie-Hellman](https://en.wikipedia.org/wiki/Elliptic-curve_Diffie%E2%80%93Hellman) (ECDH) for key exchange +- Encryption & authentication are handled via + [XChaCha20Poly1305](https://tools.ietf.org/html/rfc8439) for an authenticated + encryption scheme via + [RustCrypto/ChaCha20Poly1305](https://github.com/RustCrypto/AEADs/tree/master/chacha20poly1305) ## Installation diff --git a/core/src/client/lsp/mod.rs b/core/src/client/lsp/mod.rs index d72ee15..62bbb89 100644 --- a/core/src/client/lsp/mod.rs +++ b/core/src/client/lsp/mod.rs @@ -1,8 +1,5 @@ use super::{RemoteProcess, RemoteProcessError, RemoteStderr, RemoteStdin, RemoteStdout}; -use crate::{ - client::Session, - net::{Codec, DataStream}, -}; +use crate::client::Session; use futures::stream::{Stream, StreamExt}; use std::{ fmt::Write, @@ -26,16 +23,12 @@ pub struct RemoteLspProcess { impl RemoteLspProcess { /// Spawns the specified process on the remote machine using the given session, treating /// the process like an LSP server - pub async fn spawn( - tenant: String, - session: Session, - cmd: String, + pub async fn spawn( + tenant: impl Into, + session: &mut Session, + cmd: impl Into, args: Vec, - ) -> Result - where - T: DataStream + 'static, - U: Codec + Send + 'static, - { + ) -> Result { let mut inner = RemoteProcess::spawn(tenant, session, cmd, args).await?; let stdin = inner.stdin.take().map(RemoteLspStdin::new); let stdout = inner.stdout.take().map(RemoteLspStdout::new); @@ -266,13 +259,16 @@ mod tests { // Configures an lsp process with a means to send & receive data from outside async fn spawn_lsp_process() -> (Transport, RemoteLspProcess) { let (mut t1, t2) = Transport::make_pair(); - let session = Session::initialize(t2).unwrap(); - let spawn_task = tokio::spawn(RemoteLspProcess::spawn( - String::from("test-tenant"), - session, - String::from("cmd"), - vec![String::from("arg")], - )); + let mut session = Session::initialize(t2).unwrap(); + let spawn_task = tokio::spawn(async move { + RemoteLspProcess::spawn( + String::from("test-tenant"), + &mut session, + String::from("cmd"), + vec![String::from("arg")], + ) + .await + }); // Wait until we get the request from the session let req = t1.receive::().await.unwrap().unwrap(); @@ -280,7 +276,7 @@ mod tests { // Send back a response through the session t1.send(Response::new( "test-tenant", - Some(req.id), + req.id, vec![ResponseData::ProcStart { id: rand::random() }], )) .await @@ -524,7 +520,7 @@ mod tests { transport .send(Response::new( "test-tenant", - None, + proc.origin_id, vec![ResponseData::ProcStdout { id: proc.id(), data: make_lsp_msg(serde_json::json!({ @@ -561,7 +557,7 @@ mod tests { transport .send(Response::new( "test-tenant", - None, + proc.origin_id, vec![ResponseData::ProcStdout { id: proc.id(), data: msg_a.to_string(), @@ -580,7 +576,7 @@ mod tests { transport .send(Response::new( "test-tenant", - None, + proc.origin_id, vec![ResponseData::ProcStdout { id: proc.id(), data: msg_b.to_string(), @@ -615,7 +611,7 @@ mod tests { transport .send(Response::new( "test-tenant", - None, + proc.origin_id, vec![ResponseData::ProcStdout { id: proc.id(), data: format!("{}{}", msg, extra), @@ -659,7 +655,7 @@ mod tests { transport .send(Response::new( "test-tenant", - None, + proc.origin_id, vec![ResponseData::ProcStdout { id: proc.id(), data: format!("{}{}", msg_1, msg_2), @@ -694,7 +690,7 @@ mod tests { transport .send(Response::new( "test-tenant", - None, + proc.origin_id, vec![ResponseData::ProcStdout { id: proc.id(), data: make_lsp_msg(serde_json::json!({ @@ -725,7 +721,7 @@ mod tests { transport .send(Response::new( "test-tenant", - None, + proc.origin_id, vec![ResponseData::ProcStderr { id: proc.id(), data: make_lsp_msg(serde_json::json!({ @@ -762,7 +758,7 @@ mod tests { transport .send(Response::new( "test-tenant", - None, + proc.origin_id, vec![ResponseData::ProcStderr { id: proc.id(), data: msg_a.to_string(), @@ -781,7 +777,7 @@ mod tests { transport .send(Response::new( "test-tenant", - None, + proc.origin_id, vec![ResponseData::ProcStderr { id: proc.id(), data: msg_b.to_string(), @@ -816,7 +812,7 @@ mod tests { transport .send(Response::new( "test-tenant", - None, + proc.origin_id, vec![ResponseData::ProcStderr { id: proc.id(), data: format!("{}{}", msg, extra), @@ -860,7 +856,7 @@ mod tests { transport .send(Response::new( "test-tenant", - None, + proc.origin_id, vec![ResponseData::ProcStderr { id: proc.id(), data: format!("{}{}", msg_1, msg_2), @@ -895,7 +891,7 @@ mod tests { transport .send(Response::new( "test-tenant", - None, + proc.origin_id, vec![ResponseData::ProcStderr { id: proc.id(), data: make_lsp_msg(serde_json::json!({ diff --git a/core/src/client/process.rs b/core/src/client/process.rs index 0ba2a05..35405b1 100644 --- a/core/src/client/process.rs +++ b/core/src/client/process.rs @@ -1,8 +1,8 @@ use crate::{ - client::Session, - constants::CLIENT_BROADCAST_CHANNEL_CAPACITY, - data::{Request, RequestData, Response, ResponseData}, - net::{Codec, DataStream, TransportError}, + client::{Mailbox, Session, SessionChannel}, + constants::CLIENT_MAILBOX_CAPACITY, + data::{Request, RequestData, ResponseData}, + net::TransportError, }; use derive_more::{Display, Error, From}; use log::*; @@ -14,9 +14,6 @@ use tokio::{ #[derive(Debug, Display, Error, From)] pub enum RemoteProcessError { - /// When the process receives an unexpected response - BadResponse, - /// When attempting to relay stdout/stderr over channels, but the channels fail ChannelDead, @@ -37,6 +34,9 @@ pub struct RemoteProcess { /// Id of the process id: usize, + /// Id used to map back to mailbox + pub(crate) origin_id: usize, + /// Task that forwards stdin to the remote process by bundling it as stdin requests req_task: JoinHandle>, @@ -59,39 +59,55 @@ pub struct RemoteProcess { impl RemoteProcess { /// Spawns the specified process on the remote machine using the given session - pub async fn spawn( - tenant: String, - mut session: Session, - cmd: String, + pub async fn spawn( + tenant: impl Into, + session: &mut Session, + cmd: impl Into, args: Vec, - ) -> Result - where - T: DataStream + 'static, - U: Codec + Send + 'static, - { - // Submit our run request and wait for a response - let res = session - .send(Request::new( + ) -> Result { + let tenant = tenant.into(); + let cmd = cmd.into(); + + // Submit our run request and get back a mailbox for responses + let mut mailbox = session + .mail(Request::new( tenant.as_str(), vec![RequestData::ProcRun { cmd, args }], )) .await?; - // We expect a singular response back - if res.payload.len() != 1 { - return Err(RemoteProcessError::BadResponse); - } - - // Response should be proc starting - let id = match res.payload.into_iter().next().unwrap() { - ResponseData::ProcStart { id } => id, - _ => return Err(RemoteProcessError::BadResponse), + // Wait until we get the first response, and get id from proc started + let (id, origin_id) = match mailbox.next().await { + Some(res) if res.payload.len() != 1 => { + return Err(RemoteProcessError::TransportError(TransportError::IoError( + io::Error::new(io::ErrorKind::InvalidData, "Got wrong payload size"), + ))); + } + Some(res) => { + let origin_id = res.origin_id; + match res.payload.into_iter().next().unwrap() { + ResponseData::ProcStart { id } => (id, origin_id), + x => { + return Err(RemoteProcessError::TransportError(TransportError::IoError( + io::Error::new( + io::ErrorKind::InvalidData, + format!("Got response type of {}", x.as_ref()), + ), + ))) + } + } + } + None => { + return Err(RemoteProcessError::TransportError(TransportError::IoError( + io::Error::from(io::ErrorKind::ConnectionAborted), + ))) + } }; // Create channels for our stdin/stdout/stderr - let (stdin_tx, stdin_rx) = mpsc::channel(CLIENT_BROADCAST_CHANNEL_CAPACITY); - let (stdout_tx, stdout_rx) = mpsc::channel(CLIENT_BROADCAST_CHANNEL_CAPACITY); - let (stderr_tx, stderr_rx) = mpsc::channel(CLIENT_BROADCAST_CHANNEL_CAPACITY); + let (stdin_tx, stdin_rx) = mpsc::channel(CLIENT_MAILBOX_CAPACITY); + let (stdout_tx, stdout_rx) = mpsc::channel(CLIENT_MAILBOX_CAPACITY); + let (stderr_tx, stderr_rx) = mpsc::channel(CLIENT_MAILBOX_CAPACITY); // Used to terminate request task, either explicitly by the process or internally // by the response task when it terminates @@ -100,18 +116,19 @@ impl RemoteProcess { // Now we spawn a task to handle future responses that are async // such as ProcStdout, ProcStderr, and ProcDone let kill_tx_2 = kill_tx.clone(); - let broadcast = session.broadcast.take().unwrap(); let res_task = tokio::spawn(async move { - process_incoming_responses(id, broadcast, stdout_tx, stderr_tx, kill_tx_2).await + process_incoming_responses(id, mailbox, stdout_tx, stderr_tx, kill_tx_2).await }); // Spawn a task that takes stdin from our channel and forwards it to the remote process + let channel = session.clone_channel(); let req_task = tokio::spawn(async move { - process_outgoing_requests(tenant, id, session, stdin_rx, kill_rx).await + process_outgoing_requests(tenant, id, channel, stdin_rx, kill_rx).await }); Ok(Self { id, + origin_id, req_task, res_task, stdin: Some(RemoteStdin(stdin_tx)), @@ -196,22 +213,18 @@ impl RemoteStderr { /// Helper function that loops, processing outgoing stdin requests to a remote process as well as /// supporting a kill request to terminate the remote process -async fn process_outgoing_requests( +async fn process_outgoing_requests( tenant: String, id: usize, - mut session: Session, + mut channel: SessionChannel, mut stdin_rx: mpsc::Receiver, mut kill_rx: mpsc::Receiver<()>, -) -> Result<(), RemoteProcessError> -where - T: DataStream, - U: Codec, -{ +) -> Result<(), RemoteProcessError> { let result = loop { tokio::select! { data = stdin_rx.recv() => { match data { - Some(data) => session.fire( + Some(data) => channel.fire( Request::new( tenant.as_str(), vec![RequestData::ProcStdin { id, data }] @@ -222,12 +235,10 @@ where } msg = kill_rx.recv() => { if msg.is_some() { - session - .fire(Request::new( - tenant.as_str(), - vec![RequestData::ProcKill { id }], - )) - .await?; + channel.fire(Request::new( + tenant.as_str(), + vec![RequestData::ProcKill { id }], + )).await?; break Ok(()); } else { break Err(RemoteProcessError::ChannelDead); @@ -243,12 +254,12 @@ where /// Helper function that loops, processing incoming stdout & stderr requests from a remote process async fn process_incoming_responses( proc_id: usize, - mut broadcast: mpsc::Receiver, + mut mailbox: Mailbox, stdout_tx: mpsc::Sender, stderr_tx: mpsc::Sender, kill_tx: mpsc::Sender<()>, ) -> Result<(bool, Option), RemoteProcessError> { - while let Some(res) = broadcast.recv().await { + while let Some(res) = mailbox.next().await { // Check if any of the payload data is the termination let exit_status = res.payload.iter().find_map(|data| match data { ResponseData::ProcDone { id, success, code } if *id == proc_id => { @@ -291,61 +302,68 @@ async fn process_incoming_responses( mod tests { use super::*; use crate::{ - data::{Error, ErrorKind}, + data::{Error, ErrorKind, Response}, net::{InmemoryStream, PlainCodec, Transport}, }; - fn make_session() -> ( - Transport, - Session, - ) { + fn make_session() -> (Transport, Session) { let (t1, t2) = Transport::make_pair(); (t1, Session::initialize(t2).unwrap()) } #[tokio::test] - async fn spawn_should_return_bad_response_if_payload_size_unexpected() { - let (mut transport, session) = make_session(); + async fn spawn_should_return_invalid_data_if_payload_size_unexpected() { + let (mut transport, mut session) = make_session(); // Create a task for process spawning as we need to handle the request and a response // in a separate async block - let spawn_task = tokio::spawn(RemoteProcess::spawn( - String::from("test-tenant"), - session, - String::from("cmd"), - vec![String::from("arg")], - )); + let spawn_task = tokio::spawn(async move { + RemoteProcess::spawn( + String::from("test-tenant"), + &mut session, + String::from("cmd"), + vec![String::from("arg")], + ) + .await + }); // Wait until we get the request from the session let req = transport.receive::().await.unwrap().unwrap(); // Send back a response through the session transport - .send(Response::new("test-tenant", Some(req.id), Vec::new())) + .send(Response::new("test-tenant", req.id, Vec::new())) .await .unwrap(); // Get the spawn result and verify let result = spawn_task.await.unwrap(); assert!( - matches!(result, Err(RemoteProcessError::BadResponse)), + matches!( + &result, + Err(RemoteProcessError::TransportError(TransportError::IoError(x))) + if x.kind() == io::ErrorKind::InvalidData + ), "Unexpected result: {:?}", result ); } #[tokio::test] - async fn spawn_should_return_bad_response_if_did_not_get_a_indicator_that_process_started() { - let (mut transport, session) = make_session(); + async fn spawn_should_return_invalid_data_if_did_not_get_a_indicator_that_process_started() { + let (mut transport, mut session) = make_session(); // Create a task for process spawning as we need to handle the request and a response // in a separate async block - let spawn_task = tokio::spawn(RemoteProcess::spawn( - String::from("test-tenant"), - session, - String::from("cmd"), - vec![String::from("arg")], - )); + let spawn_task = tokio::spawn(async move { + RemoteProcess::spawn( + String::from("test-tenant"), + &mut session, + String::from("cmd"), + vec![String::from("arg")], + ) + .await + }); // Wait until we get the request from the session let req = transport.receive::().await.unwrap().unwrap(); @@ -354,7 +372,7 @@ mod tests { transport .send(Response::new( "test-tenant", - Some(req.id), + req.id, vec![ResponseData::Error(Error { kind: ErrorKind::Other, description: String::from("some error"), @@ -366,7 +384,11 @@ mod tests { // Get the spawn result and verify let result = spawn_task.await.unwrap(); assert!( - matches!(result, Err(RemoteProcessError::BadResponse)), + matches!( + &result, + Err(RemoteProcessError::TransportError(TransportError::IoError(x))) + if x.kind() == io::ErrorKind::InvalidData + ), "Unexpected result: {:?}", result ); @@ -374,16 +396,19 @@ mod tests { #[tokio::test] async fn kill_should_return_error_if_internal_tasks_already_completed() { - let (mut transport, session) = make_session(); + let (mut transport, mut session) = make_session(); // Create a task for process spawning as we need to handle the request and a response // in a separate async block - let spawn_task = tokio::spawn(RemoteProcess::spawn( - String::from("test-tenant"), - session, - String::from("cmd"), - vec![String::from("arg")], - )); + let spawn_task = tokio::spawn(async move { + RemoteProcess::spawn( + String::from("test-tenant"), + &mut session, + String::from("cmd"), + vec![String::from("arg")], + ) + .await + }); // Wait until we get the request from the session let req = transport.receive::().await.unwrap().unwrap(); @@ -393,7 +418,7 @@ mod tests { transport .send(Response::new( "test-tenant", - Some(req.id), + req.id, vec![ResponseData::ProcStart { id }], )) .await @@ -416,16 +441,19 @@ mod tests { #[tokio::test] async fn kill_should_send_proc_kill_request_and_then_cause_stdin_forwarding_to_close() { - let (mut transport, session) = make_session(); + let (mut transport, mut session) = make_session(); // Create a task for process spawning as we need to handle the request and a response // in a separate async block - let spawn_task = tokio::spawn(RemoteProcess::spawn( - String::from("test-tenant"), - session, - String::from("cmd"), - vec![String::from("arg")], - )); + let spawn_task = tokio::spawn(async move { + RemoteProcess::spawn( + String::from("test-tenant"), + &mut session, + String::from("cmd"), + vec![String::from("arg")], + ) + .await + }); // Wait until we get the request from the session let req = transport.receive::().await.unwrap().unwrap(); @@ -435,7 +463,7 @@ mod tests { transport .send(Response::new( "test-tenant", - Some(req.id), + req.id, vec![ResponseData::ProcStart { id }], )) .await @@ -469,16 +497,19 @@ mod tests { #[tokio::test] async fn stdin_should_be_forwarded_from_receiver_field() { - let (mut transport, session) = make_session(); + let (mut transport, mut session) = make_session(); // Create a task for process spawning as we need to handle the request and a response // in a separate async block - let spawn_task = tokio::spawn(RemoteProcess::spawn( - String::from("test-tenant"), - session, - String::from("cmd"), - vec![String::from("arg")], - )); + let spawn_task = tokio::spawn(async move { + RemoteProcess::spawn( + String::from("test-tenant"), + &mut session, + String::from("cmd"), + vec![String::from("arg")], + ) + .await + }); // Wait until we get the request from the session let req = transport.receive::().await.unwrap().unwrap(); @@ -488,7 +519,7 @@ mod tests { transport .send(Response::new( "test-tenant", - Some(req.id), + req.id, vec![ResponseData::ProcStart { id }], )) .await @@ -521,16 +552,19 @@ mod tests { #[tokio::test] async fn stdout_should_be_forwarded_to_receiver_field() { - let (mut transport, session) = make_session(); + let (mut transport, mut session) = make_session(); // Create a task for process spawning as we need to handle the request and a response // in a separate async block - let spawn_task = tokio::spawn(RemoteProcess::spawn( - String::from("test-tenant"), - session, - String::from("cmd"), - vec![String::from("arg")], - )); + let spawn_task = tokio::spawn(async move { + RemoteProcess::spawn( + String::from("test-tenant"), + &mut session, + String::from("cmd"), + vec![String::from("arg")], + ) + .await + }); // Wait until we get the request from the session let req = transport.receive::().await.unwrap().unwrap(); @@ -540,7 +574,7 @@ mod tests { transport .send(Response::new( "test-tenant", - Some(req.id), + req.id, vec![ResponseData::ProcStart { id }], )) .await @@ -552,7 +586,7 @@ mod tests { transport .send(Response::new( "test-tenant", - None, + req.id, vec![ResponseData::ProcStdout { id, data: String::from("some out"), @@ -567,16 +601,19 @@ mod tests { #[tokio::test] async fn stderr_should_be_forwarded_to_receiver_field() { - let (mut transport, session) = make_session(); + let (mut transport, mut session) = make_session(); // Create a task for process spawning as we need to handle the request and a response // in a separate async block - let spawn_task = tokio::spawn(RemoteProcess::spawn( - String::from("test-tenant"), - session, - String::from("cmd"), - vec![String::from("arg")], - )); + let spawn_task = tokio::spawn(async move { + RemoteProcess::spawn( + String::from("test-tenant"), + &mut session, + String::from("cmd"), + vec![String::from("arg")], + ) + .await + }); // Wait until we get the request from the session let req = transport.receive::().await.unwrap().unwrap(); @@ -586,7 +623,7 @@ mod tests { transport .send(Response::new( "test-tenant", - Some(req.id), + req.id, vec![ResponseData::ProcStart { id }], )) .await @@ -598,7 +635,7 @@ mod tests { transport .send(Response::new( "test-tenant", - None, + req.id, vec![ResponseData::ProcStderr { id, data: String::from("some err"), @@ -613,16 +650,19 @@ mod tests { #[tokio::test] async fn wait_should_return_error_if_internal_tasks_fail() { - let (mut transport, session) = make_session(); + let (mut transport, mut session) = make_session(); // Create a task for process spawning as we need to handle the request and a response // in a separate async block - let spawn_task = tokio::spawn(RemoteProcess::spawn( - String::from("test-tenant"), - session, - String::from("cmd"), - vec![String::from("arg")], - )); + let spawn_task = tokio::spawn(async move { + RemoteProcess::spawn( + String::from("test-tenant"), + &mut session, + String::from("cmd"), + vec![String::from("arg")], + ) + .await + }); // Wait until we get the request from the session let req = transport.receive::().await.unwrap().unwrap(); @@ -632,7 +672,7 @@ mod tests { transport .send(Response::new( "test-tenant", - Some(req.id), + req.id, vec![ResponseData::ProcStart { id }], )) .await @@ -652,16 +692,19 @@ mod tests { #[tokio::test] async fn wait_should_return_error_if_connection_terminates_before_receiving_done_response() { - let (mut transport, session) = make_session(); + let (mut transport, mut session) = make_session(); // Create a task for process spawning as we need to handle the request and a response // in a separate async block - let spawn_task = tokio::spawn(RemoteProcess::spawn( - String::from("test-tenant"), - session, - String::from("cmd"), - vec![String::from("arg")], - )); + let spawn_task = tokio::spawn(async move { + RemoteProcess::spawn( + String::from("test-tenant"), + &mut session, + String::from("cmd"), + vec![String::from("arg")], + ) + .await + }); // Wait until we get the request from the session let req = transport.receive::().await.unwrap().unwrap(); @@ -671,7 +714,7 @@ mod tests { transport .send(Response::new( "test-tenant", - Some(req.id), + req.id, vec![ResponseData::ProcStart { id }], )) .await @@ -679,6 +722,10 @@ mod tests { // Receive the process and then terminate session connection let proc = spawn_task.await.unwrap().unwrap(); + + // Ensure that the spawned task gets a chance to wait on stdout/stderr + tokio::task::yield_now().await; + drop(transport); // Ensure that the other tasks are cancelled before continuing @@ -694,16 +741,19 @@ mod tests { #[tokio::test] async fn receiving_done_response_should_result_in_wait_returning_exit_information() { - let (mut transport, session) = make_session(); + let (mut transport, mut session) = make_session(); // Create a task for process spawning as we need to handle the request and a response // in a separate async block - let spawn_task = tokio::spawn(RemoteProcess::spawn( - String::from("test-tenant"), - session, - String::from("cmd"), - vec![String::from("arg")], - )); + let spawn_task = tokio::spawn(async move { + RemoteProcess::spawn( + String::from("test-tenant"), + &mut session, + String::from("cmd"), + vec![String::from("arg")], + ) + .await + }); // Wait until we get the request from the session let req = transport.receive::().await.unwrap().unwrap(); @@ -713,7 +763,7 @@ mod tests { transport .send(Response::new( "test-tenant", - Some(req.id), + req.id, vec![ResponseData::ProcStart { id }], )) .await @@ -727,7 +777,7 @@ mod tests { transport .send(Response::new( "test-tenant", - None, + req.id, vec![ResponseData::ProcDone { id, success: false, diff --git a/core/src/client/session/ext.rs b/core/src/client/session/ext.rs new file mode 100644 index 0000000..41b3f2e --- /dev/null +++ b/core/src/client/session/ext.rs @@ -0,0 +1,435 @@ +use crate::{ + client::{RemoteProcess, RemoteProcessError, Session}, + data::{DirEntry, Error as Failure, FileType, Request, RequestData, ResponseData}, + net::TransportError, +}; +use derive_more::{Display, Error, From}; +use std::{future::Future, path::PathBuf, pin::Pin}; + +/// Represents an error that can occur related to convenience functions tied to a [`Session`] +#[derive(Debug, Display, Error, From)] +pub enum SessionExtError { + /// Occurs when the remote action fails + Failure(#[error(not(source))] Failure), + + /// Occurs when a transport error is encountered + TransportError(TransportError), + + /// Occurs when receiving a response that was not expected + MismatchedResponse, +} + +pub type AsyncReturn<'a, T, E = SessionExtError> = + Pin> + Send + 'a>>; + +/// Represents metadata about some path on a remote machine +pub struct Metadata { + pub file_type: FileType, + pub len: u64, + pub readonly: bool, + + pub canonicalized_path: Option, + + pub accessed: Option, + pub created: Option, + pub modified: Option, +} + +/// Provides convenience functions on top of a [`Session`] +pub trait SessionExt { + /// Appends to a remote file using the data from a collection of bytes + fn append_file( + &mut self, + tenant: impl Into, + path: impl Into, + data: impl Into>, + ) -> AsyncReturn<'_, ()>; + + /// Appends to a remote file using the data from a string + fn append_file_text( + &mut self, + tenant: impl Into, + path: impl Into, + data: impl Into, + ) -> AsyncReturn<'_, ()>; + + /// Copies a remote file or directory from src to dst + fn copy( + &mut self, + tenant: impl Into, + src: impl Into, + dst: impl Into, + ) -> AsyncReturn<'_, ()>; + + /// Creates a remote directory, optionally creating all parent components if specified + fn create_dir( + &mut self, + tenant: impl Into, + path: impl Into, + all: bool, + ) -> AsyncReturn<'_, ()>; + + /// Checks if a path exists on a remote machine + fn exists( + &mut self, + tenant: impl Into, + path: impl Into, + ) -> AsyncReturn<'_, bool>; + + /// Retrieves metadata about a path on a remote machine + fn metadata( + &mut self, + tenant: impl Into, + path: impl Into, + canonicalize: bool, + resolve_file_type: bool, + ) -> AsyncReturn<'_, Metadata>; + + /// Reads entries from a directory, returning a tuple of directory entries and failures + fn read_dir( + &mut self, + tenant: impl Into, + path: impl Into, + depth: usize, + absolute: bool, + canonicalize: bool, + include_root: bool, + ) -> AsyncReturn<'_, (Vec, Vec)>; + + /// Reads a remote file as a collection of bytes + fn read_file( + &mut self, + tenant: impl Into, + path: impl Into, + ) -> AsyncReturn<'_, Vec>; + + /// Returns a remote file as a string + fn read_file_text( + &mut self, + tenant: impl Into, + path: impl Into, + ) -> AsyncReturn<'_, String>; + + /// Removes a remote file or directory, supporting removal of non-empty directories if + /// force is true + fn remove( + &mut self, + tenant: impl Into, + path: impl Into, + force: bool, + ) -> AsyncReturn<'_, ()>; + + /// Renames a remote file or directory from src to dst + fn rename( + &mut self, + tenant: impl Into, + src: impl Into, + dst: impl Into, + ) -> AsyncReturn<'_, ()>; + + /// Spawns a process on the remote machine + fn spawn( + &mut self, + tenant: impl Into, + cmd: impl Into, + args: Vec, + ) -> AsyncReturn<'_, RemoteProcess, RemoteProcessError>; + + /// Writes a remote file with the data from a collection of bytes + fn write_file( + &mut self, + tenant: impl Into, + path: impl Into, + data: impl Into>, + ) -> AsyncReturn<'_, ()>; + + /// Writes a remote file with the data from a string + fn write_file_text( + &mut self, + tenant: impl Into, + path: impl Into, + data: impl Into, + ) -> AsyncReturn<'_, ()>; +} + +macro_rules! make_body { + ($self:expr, $tenant:expr, $data:expr, @ok) => { + make_body!($self, $tenant, $data, |data| { + if data.is_ok() { + Ok(()) + } else { + Err(SessionExtError::MismatchedResponse) + } + }) + }; + + ($self:expr, $tenant:expr, $data:expr, $and_then:expr) => {{ + let req = Request::new($tenant, vec![$data]); + Box::pin(async move { + $self + .send(req) + .await + .map_err(SessionExtError::from) + .and_then(|res| { + if res.payload.len() == 1 { + Ok(res.payload.into_iter().next().unwrap()) + } else { + Err(SessionExtError::MismatchedResponse) + } + }) + .and_then($and_then) + }) + }}; +} + +impl SessionExt for Session { + /// Appends to a remote file using the data from a collection of bytes + fn append_file( + &mut self, + tenant: impl Into, + path: impl Into, + data: impl Into>, + ) -> AsyncReturn<'_, ()> { + make_body!( + self, + tenant, + RequestData::FileAppend { path: path.into(), data: data.into() }, + @ok + ) + } + + /// Appends to a remote file using the data from a string + fn append_file_text( + &mut self, + tenant: impl Into, + path: impl Into, + data: impl Into, + ) -> AsyncReturn<'_, ()> { + make_body!( + self, + tenant, + RequestData::FileAppendText { path: path.into(), text: data.into() }, + @ok + ) + } + + /// Copies a remote file or directory from src to dst + fn copy( + &mut self, + tenant: impl Into, + src: impl Into, + dst: impl Into, + ) -> AsyncReturn<'_, ()> { + make_body!( + self, + tenant, + RequestData::Copy { src: src.into(), dst: dst.into() }, + @ok + ) + } + + /// Creates a remote directory, optionally creating all parent components if specified + fn create_dir( + &mut self, + tenant: impl Into, + path: impl Into, + all: bool, + ) -> AsyncReturn<'_, ()> { + make_body!( + self, + tenant, + RequestData::DirCreate { path: path.into(), all }, + @ok + ) + } + + /// Checks if a path exists on a remote machine + fn exists( + &mut self, + tenant: impl Into, + path: impl Into, + ) -> AsyncReturn<'_, bool> { + make_body!( + self, + tenant, + RequestData::Exists { path: path.into() }, + |data| match data { + ResponseData::Exists(x) => Ok(x), + _ => Err(SessionExtError::MismatchedResponse), + } + ) + } + + /// Retrieves metadata about a path on a remote machine + fn metadata( + &mut self, + tenant: impl Into, + path: impl Into, + canonicalize: bool, + resolve_file_type: bool, + ) -> AsyncReturn<'_, Metadata> { + make_body!( + self, + tenant, + RequestData::Metadata { + path: path.into(), + canonicalize, + resolve_file_type + }, + |data| match data { + ResponseData::Metadata { + canonicalized_path, + file_type, + len, + readonly, + accessed, + created, + modified, + } => Ok(Metadata { + canonicalized_path, + file_type, + len, + readonly, + accessed, + created, + modified, + }), + _ => Err(SessionExtError::MismatchedResponse), + } + ) + } + + /// Reads entries from a directory, returning a tuple of directory entries and failures + fn read_dir( + &mut self, + tenant: impl Into, + path: impl Into, + depth: usize, + absolute: bool, + canonicalize: bool, + include_root: bool, + ) -> AsyncReturn<'_, (Vec, Vec)> { + make_body!( + self, + tenant, + RequestData::DirRead { + path: path.into(), + depth, + absolute, + canonicalize, + include_root + }, + |data| match data { + ResponseData::DirEntries { entries, errors } => Ok((entries, errors)), + _ => Err(SessionExtError::MismatchedResponse), + } + ) + } + + /// Reads a remote file as a collection of bytes + fn read_file( + &mut self, + tenant: impl Into, + path: impl Into, + ) -> AsyncReturn<'_, Vec> { + make_body!( + self, + tenant, + RequestData::FileRead { path: path.into() }, + |data| match data { + ResponseData::Blob { data } => Ok(data), + _ => Err(SessionExtError::MismatchedResponse), + } + ) + } + + /// Returns a remote file as a string + fn read_file_text( + &mut self, + tenant: impl Into, + path: impl Into, + ) -> AsyncReturn<'_, String> { + make_body!( + self, + tenant, + RequestData::FileReadText { path: path.into() }, + |data| match data { + ResponseData::Text { data } => Ok(data), + _ => Err(SessionExtError::MismatchedResponse), + } + ) + } + + /// Removes a remote file or directory, supporting removal of non-empty directories if + /// force is true + fn remove( + &mut self, + tenant: impl Into, + path: impl Into, + force: bool, + ) -> AsyncReturn<'_, ()> { + make_body!( + self, + tenant, + RequestData::Remove { path: path.into(), force }, + @ok + ) + } + + /// Renames a remote file or directory from src to dst + fn rename( + &mut self, + tenant: impl Into, + src: impl Into, + dst: impl Into, + ) -> AsyncReturn<'_, ()> { + make_body!( + self, + tenant, + RequestData::Rename { src: src.into(), dst: dst.into() }, + @ok + ) + } + + /// Spawns a process on the remote machine + fn spawn( + &mut self, + tenant: impl Into, + cmd: impl Into, + args: Vec, + ) -> AsyncReturn<'_, RemoteProcess, RemoteProcessError> { + let tenant = tenant.into(); + let cmd = cmd.into(); + Box::pin(async move { RemoteProcess::spawn(tenant, self, cmd, args).await }) + } + + /// Writes a remote file with the data from a collection of bytes + fn write_file( + &mut self, + tenant: impl Into, + path: impl Into, + data: impl Into>, + ) -> AsyncReturn<'_, ()> { + make_body!( + self, + tenant, + RequestData::FileWrite { path: path.into(), data: data.into() }, + @ok + ) + } + + /// Writes a remote file with the data from a string + fn write_file_text( + &mut self, + tenant: impl Into, + path: impl Into, + data: impl Into, + ) -> AsyncReturn<'_, ()> { + make_body!( + self, + tenant, + RequestData::FileWriteText { path: path.into(), text: data.into() }, + @ok + ) + } +} diff --git a/core/src/client/session/mailbox.rs b/core/src/client/session/mailbox.rs new file mode 100644 index 0000000..20d44a1 --- /dev/null +++ b/core/src/client/session/mailbox.rs @@ -0,0 +1,84 @@ +use crate::{client::utils, data::Response}; +use std::{collections::HashMap, time::Duration}; +use tokio::{io, sync::mpsc}; + +pub struct PostOffice { + mailboxes: HashMap>, +} + +impl PostOffice { + pub fn new() -> Self { + Self { + mailboxes: HashMap::new(), + } + } + + /// Creates a new mailbox using the given id and buffer size for maximum messages + pub fn make_mailbox(&mut self, id: usize, buffer: usize) -> Mailbox { + let (tx, rx) = mpsc::channel(buffer); + self.mailboxes.insert(id, tx); + + Mailbox { id, rx } + } + + /// Delivers a response to appropriate mailbox, returning false if no mailbox is found + /// for the response or if the mailbox is no longer receiving responses + pub async fn deliver(&mut self, res: Response) -> bool { + let id = res.origin_id; + + let success = if let Some(tx) = self.mailboxes.get_mut(&id) { + tx.send(res).await.is_ok() + } else { + false + }; + + // If failed, we want to remove the mailbox sender as it is no longer valid + if !success { + self.mailboxes.remove(&id); + } + + success + } + + /// Removes all mailboxes from post office that are closed + pub fn prune_mailboxes(&mut self) { + self.mailboxes.retain(|_, tx| !tx.is_closed()) + } + + /// Closes out all mailboxes by removing the mailboxes delivery trackers internally + pub fn clear_mailboxes(&mut self) { + self.mailboxes.clear(); + } +} + +pub struct Mailbox { + /// Represents id associated with the mailbox + id: usize, + + /// Underlying mailbox storage + rx: mpsc::Receiver, +} + +impl Mailbox { + /// Represents id associated with the mailbox + pub fn id(&self) -> usize { + self.id + } + + /// Receives next response in mailbox + pub async fn next(&mut self) -> Option { + self.rx.recv().await + } + + /// Receives next response in mailbox, waiting up to duration before timing out + pub async fn next_timeout(&mut self, duration: Duration) -> io::Result> { + utils::timeout(duration, self.next()).await + } + + /// Closes the mailbox such that it will not receive any more responses + /// + /// Any responses already in the mailbox will still be returned via `next` + pub async fn close(&mut self) { + self.rx.close() + } +} diff --git a/core/src/client/session/mod.rs b/core/src/client/session/mod.rs index c971f2e..afeddb0 100644 --- a/core/src/client/session/mod.rs +++ b/core/src/client/session/mod.rs @@ -1,51 +1,55 @@ use crate::{ client::utils, - constants::CLIENT_BROADCAST_CHANNEL_CAPACITY, + constants::CLIENT_MAILBOX_CAPACITY, data::{Request, Response}, - net::{Codec, DataStream, Transport, TransportError, TransportWriteHalf}, + net::{Codec, DataStream, Transport, TransportError}, }; use log::*; use std::{ - collections::HashMap, convert, net::SocketAddr, - sync::{Arc, Mutex}, + ops::{Deref, DerefMut}, + sync::{Arc, Weak}, }; use tokio::{ io, net::TcpStream, - sync::{mpsc, oneshot}, + sync::{mpsc, Mutex}, task::{JoinError, JoinHandle}, time::Duration, }; +mod ext; +pub use ext::SessionExt; + mod info; pub use info::{SessionInfo, SessionInfoFile, SessionInfoParseError}; -type Callbacks = Arc>>>; +mod mailbox; +pub use mailbox::Mailbox; +use mailbox::PostOffice; /// Represents a session with a remote server that can be used to send requests & receive responses -pub struct Session -where - T: DataStream, - U: Codec, -{ - /// Underlying transport used by session - t_write: TransportWriteHalf, +pub struct Session { + /// Used to send requests to a server + channel: SessionChannel, - /// Collection of callbacks to be invoked upon receiving a response to a request - callbacks: Callbacks, + /// Contains the task that is running to send requests to a server + request_task: JoinHandle<()>, /// Contains the task that is running to receive responses from a server response_task: JoinHandle<()>, - /// Represents the receiver for broadcasted responses (ones with no callback) - pub broadcast: Option>, + /// Contains the task that runs on a timer to prune closed mailboxes + prune_task: JoinHandle<()>, } -impl Session { +impl Session { /// Connect to a remote TCP server using the provided information - pub async fn tcp_connect(addr: SocketAddr, codec: U) -> io::Result { + pub async fn tcp_connect(addr: SocketAddr, codec: U) -> io::Result + where + U: Codec + Send + 'static, + { let transport = Transport::::connect(addr, codec).await?; debug!( "Session has been established with {}", @@ -58,11 +62,14 @@ impl Session { } /// Connect to a remote TCP server, timing out after duration has passed - pub async fn tcp_connect_timeout( + pub async fn tcp_connect_timeout( addr: SocketAddr, codec: U, duration: Duration, - ) -> io::Result { + ) -> io::Result + where + U: Codec + Send + 'static, + { utils::timeout(duration, Self::tcp_connect(addr, codec)) .await .and_then(convert::identity) @@ -70,9 +77,12 @@ impl Session { } #[cfg(unix)] -impl Session { +impl Session { /// Connect to a proxy unix socket - pub async fn unix_connect(path: impl AsRef, codec: U) -> io::Result { + pub async fn unix_connect(path: impl AsRef, codec: U) -> io::Result + where + U: Codec + Send + 'static, + { let transport = Transport::::connect(path, codec).await?; debug!( "Session has been established with {}", @@ -85,53 +95,50 @@ impl Session { } /// Connect to a proxy unix socket, timing out after duration has passed - pub async fn unix_connect_timeout( + pub async fn unix_connect_timeout( path: impl AsRef, codec: U, duration: Duration, - ) -> io::Result { + ) -> io::Result + where + U: Codec + Send + 'static, + { utils::timeout(duration, Self::unix_connect(path, codec)) .await .and_then(convert::identity) } } -impl Session -where - T: DataStream, - U: Codec + Send + 'static, -{ +impl Session { /// Initializes a session using the provided transport - pub fn initialize(transport: Transport) -> io::Result { - let (mut t_read, t_write) = transport.into_split(); - let callbacks: Callbacks = Arc::new(Mutex::new(HashMap::new())); - let (broadcast_tx, broadcast_rx) = mpsc::channel(CLIENT_BROADCAST_CHANNEL_CAPACITY); - - // Start a task that continually checks for responses and triggers callbacks - let callbacks_2 = Arc::clone(&callbacks); + pub fn initialize(transport: Transport) -> io::Result + where + T: DataStream, + U: Codec + Send + 'static, + { + let (mut t_read, mut t_write) = transport.into_split(); + let post_office = Arc::new(Mutex::new(PostOffice::new())); + let weak_post_office = Arc::downgrade(&post_office); + + // Start a task that continually checks for responses and delivers them using the + // post office let response_task = tokio::spawn(async move { loop { match t_read.receive::().await { Ok(Some(res)) => { trace!("Incoming response: {:?}", res); - let maybe_callback = res - .origin_id - .as_ref() - .and_then(|id| callbacks_2.lock().unwrap().remove(id)); - - // If there is an origin to this response, trigger the callback - if let Some(tx) = maybe_callback { - trace!("Callback exists for response! Triggering!"); - if let Err(res) = tx.send(res) { - error!("Failed to trigger callback for response {}", res.id); - } - - // Otherwise, this goes into the junk draw of response handlers - } else { - trace!("Callback missing for response! Broadcasting!"); - if let Err(x) = broadcast_tx.send(res).await { - error!("Failed to trigger broadcast: {}", x); - } + let res_id = res.id; + let res_origin_id = res.origin_id; + + // Try to send response to appropriate mailbox + // NOTE: We don't log failures as errors as using fire(...) for a + // session is valid and would not have a mailbox + if !post_office.lock().await.deliver(res).await { + trace!( + "Response {} has no mailbox for origin {}", + res_id, + res_origin_id + ); } } Ok(None) => { @@ -144,47 +151,144 @@ where } } } + + // Clean up remaining mailbox senders + post_office.lock().await.clear_mailboxes(); + }); + + let (tx, mut rx) = mpsc::channel::(1); + let request_task = tokio::spawn(async move { + while let Some(req) = rx.recv().await { + if let Err(x) = t_write.send(req).await { + error!("Failed to send request to server: {}", x); + break; + } + } }); + // Create a task that runs once a minute and prunes mailboxes + let post_office = Weak::clone(&weak_post_office); + let prune_task = tokio::spawn(async move { + loop { + tokio::time::sleep(Duration::from_secs(60)).await; + if let Some(post_office) = Weak::upgrade(&post_office) { + post_office.lock().await.prune_mailboxes(); + } else { + break; + } + } + }); + + let channel = SessionChannel { + tx, + post_office: weak_post_office, + }; + Ok(Self { - t_write, - callbacks, - broadcast: Some(broadcast_rx), + channel, + request_task, response_task, + prune_task, }) } } -impl Session -where - T: DataStream, - U: Codec, -{ +impl Session { /// Waits for the session to terminate, which results when the receiving end of the network /// connection is closed (or the session is shutdown) pub async fn wait(self) -> Result<(), JoinError> { - self.response_task.await + self.prune_task.abort(); + tokio::try_join!(self.request_task, self.response_task).map(|_| ()) } /// Abort the session's current connection by forcing its response task to shutdown pub fn abort(&self) { - self.response_task.abort() + self.request_task.abort(); + self.response_task.abort(); + self.prune_task.abort(); } - /// Sends a request and waits for a response - pub async fn send(&mut self, req: Request) -> Result { - trace!("Sending request: {:?}", req); + /// Clones the underlying channel for requests and returns the cloned instance + pub fn clone_channel(&self) -> SessionChannel { + self.channel.clone() + } +} + +impl Deref for Session { + type Target = SessionChannel; + + fn deref(&self) -> &Self::Target { + &self.channel + } +} - // First, add a callback that will trigger when we get the response for this request - let (tx, rx) = oneshot::channel(); - self.callbacks.lock().unwrap().insert(req.id, tx); +impl DerefMut for Session { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.channel + } +} + +impl From for SessionChannel { + fn from(session: Session) -> Self { + session.channel + } +} + +/// Represents a sender of requests tied to a session, holding onto a weak reference of +/// mailboxes to relay responses, meaning that once the [`Session`] is closed or dropped, +/// any sent request will no longer be able to receive responses +#[derive(Clone)] +pub struct SessionChannel { + /// Used to send requests to a server + tx: mpsc::Sender, + + /// Collection of mailboxes for receiving responses to requests + post_office: Weak>, +} + +impl SessionChannel { + /// Returns true if no more requests can be transferred + pub fn is_closed(&self) -> bool { + self.tx.is_closed() + } + + /// Sends a request and returns a mailbox that can receive one or more responses, failing if + /// unable to send a request or if the session's receiving line to the remote server has + /// already been severed + pub async fn mail(&mut self, req: Request) -> Result { + trace!("Mailing request: {:?}", req); + + // First, create a mailbox using the request's id + let mailbox = Weak::upgrade(&self.post_office) + .ok_or_else(|| { + TransportError::IoError(io::Error::new( + io::ErrorKind::NotConnected, + "Session's post office is no longer available", + )) + })? + .lock() + .await + .make_mailbox(req.id, CLIENT_MAILBOX_CAPACITY); // Second, send the request - self.t_write.send(req).await?; + self.fire(req).await?; - // Third, wait for the response - rx.await - .map_err(|x| TransportError::from(io::Error::new(io::ErrorKind::ConnectionAborted, x))) + // Third, return mailbox + Ok(mailbox) + } + + /// Sends a request and waits for a response, failing if unable to send a request or if + /// the session's receiving line to the remote server has already been severed + pub async fn send(&mut self, req: Request) -> Result { + trace!("Sending request: {:?}", req); + + // Send mail and get back a mailbox + let mut mailbox = self.mail(req).await?; + + // Wait for first response, and then drop the mailbox + mailbox.next().await.ok_or_else(|| { + TransportError::IoError(io::Error::from(io::ErrorKind::ConnectionAborted)) + }) } /// Sends a request and waits for a response, timing out after duration has passed @@ -199,12 +303,14 @@ where .and_then(convert::identity) } - /// Sends a request without waiting for a response - /// - /// Any response that would be received gets sent over the broadcast channel instead + /// Sends a request without waiting for a response; this method is able to be used even + /// if the session's receiving line to the remote server has been severed pub async fn fire(&mut self, req: Request) -> Result<(), TransportError> { trace!("Firing off request: {:?}", req); - self.t_write.send(req).await + self.tx + .send(req) + .await + .map_err(|x| TransportError::IoError(io::Error::new(io::ErrorKind::BrokenPipe, x))) } /// Sends a request without waiting for a response, timing out after duration has passed @@ -229,6 +335,40 @@ mod tests { }; use std::time::Duration; + #[tokio::test] + async fn mail_should_return_mailbox_that_receives_responses_until_transport_closes() { + let (t1, mut t2) = Transport::make_pair(); + let mut session = Session::initialize(t1).unwrap(); + + let req = Request::new(TENANT, vec![RequestData::ProcList {}]); + let res = Response::new(TENANT, req.id, vec![ResponseData::Ok]); + + let mut mailbox = session.mail(req).await.unwrap(); + + // Get first response + match tokio::join!(mailbox.next(), t2.send(res.clone())) { + (Some(actual), _) => assert_eq!(actual, res), + x => panic!("Unexpected response: {:?}", x), + } + + // Get second response + match tokio::join!(mailbox.next(), t2.send(res.clone())) { + (Some(actual), _) => assert_eq!(actual, res), + x => panic!("Unexpected response: {:?}", x), + } + + // Trigger the mailbox to wait BEFORE closing our transport to ensure that + // we don't get stuck if the mailbox was already waiting + let next_task = tokio::spawn(async move { mailbox.next().await }); + tokio::task::yield_now().await; + + drop(t2); + match next_task.await { + Ok(None) => {} + x => panic!("Unexpected response: {:?}", x), + } + } + #[tokio::test] async fn send_should_wait_until_response_received() { let (t1, mut t2) = Transport::make_pair(); @@ -237,7 +377,7 @@ mod tests { let req = Request::new(TENANT, vec![RequestData::ProcList {}]); let res = Response::new( TENANT, - Some(req.id), + req.id, vec![ResponseData::ProcEntries { entries: Vec::new(), }], diff --git a/core/src/constants.rs b/core/src/constants.rs index 8d89717..ddeb26e 100644 --- a/core/src/constants.rs +++ b/core/src/constants.rs @@ -1,6 +1,5 @@ -/// Capacity associated with a client broadcasting its received messages that -/// do not have a callback associated -pub const CLIENT_BROADCAST_CHANNEL_CAPACITY: usize = 10000; +/// Capacity associated with a client mailboxes for receiving multiple responses to a request +pub const CLIENT_MAILBOX_CAPACITY: usize = 10000; /// Represents the maximum size (in bytes) that data will be read from pipes /// per individual `read` call diff --git a/core/src/data.rs b/core/src/data.rs index 8fc7a48..7af9052 100644 --- a/core/src/data.rs +++ b/core/src/data.rs @@ -257,9 +257,9 @@ pub struct Response { /// A unique id associated with the response pub id: usize, - /// The id of the originating request, if there was one - /// (some responses are sent unprompted) - pub origin_id: Option, + /// The id of the originating request that yielded this response + /// (more than one response may have same origin) + pub origin_id: usize, /// The main payload containing a collection of data comprising one or more results pub payload: Vec, @@ -267,11 +267,7 @@ pub struct Response { impl Response { /// Creates a new response, generating a unique id for it - pub fn new( - tenant: impl Into, - origin_id: Option, - payload: Vec, - ) -> Self { + pub fn new(tenant: impl Into, origin_id: usize, payload: Vec) -> Self { let id = rand::random(); Self { tenant: tenant.into(), @@ -486,7 +482,8 @@ impl From for ResponseData { } /// General purpose error type that can be sent across the wire -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)] +#[display(fmt = "{}: {}", kind, description)] #[serde(rename_all = "snake_case", deny_unknown_fields)] pub struct Error { /// Label describing the kind of error diff --git a/core/src/lib.rs b/core/src/lib.rs index 95d405c..ada4126 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -1,9 +1,9 @@ mod client; pub use client::{ LspContent, LspContentParseError, LspData, LspDataParseError, LspHeader, LspHeaderParseError, - LspSessionInfoError, RemoteLspProcess, RemoteLspStderr, RemoteLspStdin, RemoteLspStdout, - RemoteProcess, RemoteProcessError, RemoteStderr, RemoteStdin, RemoteStdout, Session, - SessionInfo, SessionInfoFile, SessionInfoParseError, + LspSessionInfoError, Mailbox, RemoteLspProcess, RemoteLspStderr, RemoteLspStdin, + RemoteLspStdout, RemoteProcess, RemoteProcessError, RemoteStderr, RemoteStdin, RemoteStdout, + Session, SessionInfo, SessionInfoFile, SessionInfoParseError, }; mod constants; diff --git a/core/src/server/distant/handler.rs b/core/src/server/distant/handler.rs index 0390667..918dbd1 100644 --- a/core/src/server/distant/handler.rs +++ b/core/src/server/distant/handler.rs @@ -10,7 +10,9 @@ use futures::future; use log::*; use std::{ env, + future::Future, path::{Path, PathBuf}, + pin::Pin, process::Stdio, sync::Arc, time::SystemTime, @@ -22,8 +24,8 @@ use tokio::{ }; use walkdir::WalkDir; -pub type Reply = mpsc::Sender; type HState = Arc>; +type ReplyRet = Pin + Send + 'static>>; #[derive(Debug, Display, Error, From)] pub enum ServerError { @@ -60,15 +62,17 @@ pub(super) async fn process( conn_id: usize, state: HState, req: Request, - tx: Reply, + tx: mpsc::Sender, ) -> Result<(), mpsc::error::SendError> { - async fn inner( - tenant: Arc, + async fn inner( conn_id: usize, state: HState, data: RequestData, - tx: Reply, - ) -> Result { + reply: F, + ) -> Result + where + F: FnMut(Vec) -> ReplyRet + Clone + Send + 'static, + { match data { RequestData::FileRead { path } => file_read(path).await, RequestData::FileReadText { path } => file_read_text(path).await, @@ -93,9 +97,7 @@ pub(super) async fn process( canonicalize, resolve_file_type, } => metadata(path, canonicalize, resolve_file_type).await, - RequestData::ProcRun { cmd, args } => { - proc_run(tenant.to_string(), conn_id, state, tx, cmd, args).await - } + RequestData::ProcRun { cmd, args } => proc_run(conn_id, state, reply, cmd, args).await, RequestData::ProcKill { id } => proc_kill(state, id).await, RequestData::ProcStdin { id, data } => proc_stdin(state, id, data).await, RequestData::ProcList {} => proc_list(state).await, @@ -103,16 +105,24 @@ pub(super) async fn process( } } - let tenant = Arc::new(req.tenant.clone()); + let reply = { + let origin_id = req.id; + let tenant = req.tenant.clone(); + let tx_2 = tx.clone(); + move |payload: Vec| -> ReplyRet { + let tx = tx_2.clone(); + let res = Response::new(tenant.to_string(), origin_id, payload); + Box::pin(async move { tx.send(res).await.is_ok() }) + } + }; // Build up a collection of tasks to run independently let mut payload_tasks = Vec::new(); for data in req.payload { - let tenant_2 = Arc::clone(&tenant); let state_2 = Arc::clone(&state); - let tx_2 = tx.clone(); + let reply_2 = reply.clone(); payload_tasks.push(tokio::spawn(async move { - match inner(tenant_2, conn_id, state_2, data, tx_2).await { + match inner(conn_id, state_2, data, reply_2).await { Ok(outgoing) => outgoing, Err(x) => Outgoing::from(ResponseData::from(x)), } @@ -135,7 +145,7 @@ pub(super) async fn process( .collect(); let payload = outgoing.into_iter().map(|x| x.data).collect(); - let res = Response::new(req.tenant, Some(req.id), payload); + let res = Response::new(req.tenant, req.id, payload); // Send out our primary response from processing the request let result = tx.send(res).await; @@ -407,14 +417,16 @@ async fn metadata( })) } -async fn proc_run( - tenant: String, +async fn proc_run( conn_id: usize, state: HState, - tx: Reply, + reply: F, cmd: String, args: Vec, -) -> Result { +) -> Result +where + F: FnMut(Vec) -> ReplyRet + Clone + Send + 'static, +{ let id = rand::random(); let mut child = Command::new(cmd.to_string()) @@ -430,21 +442,16 @@ async fn proc_run( let post_hook = Box::new(move |mut state_lock: MutexGuard<'_, State>| { // Spawn a task that sends stdout as a response - let tx_2 = tx.clone(); - let tenant_2 = tenant.clone(); let mut stdout = child.stdout.take().unwrap(); + let mut reply_2 = reply.clone(); let stdout_task = tokio::spawn(async move { let mut buf: [u8; MAX_PIPE_CHUNK_SIZE] = [0; MAX_PIPE_CHUNK_SIZE]; loop { match stdout.read(&mut buf).await { Ok(n) if n > 0 => match String::from_utf8(buf[..n].to_vec()) { Ok(data) => { - let res = Response::new( - tenant_2.as_str(), - None, - vec![ResponseData::ProcStdout { id, data }], - ); - if tx_2.send(res).await.is_err() { + let payload = vec![ResponseData::ProcStdout { id, data }]; + if !reply_2(payload).await { error!(" Stdout channel closed", conn_id); break; } @@ -468,21 +475,16 @@ async fn proc_run( }); // Spawn a task that sends stderr as a response - let tx_2 = tx.clone(); - let tenant_2 = tenant.clone(); let mut stderr = child.stderr.take().unwrap(); + let mut reply_2 = reply.clone(); let stderr_task = tokio::spawn(async move { let mut buf: [u8; MAX_PIPE_CHUNK_SIZE] = [0; MAX_PIPE_CHUNK_SIZE]; loop { match stderr.read(&mut buf).await { Ok(n) if n > 0 => match String::from_utf8(buf[..n].to_vec()) { Ok(data) => { - let res = Response::new( - tenant_2.as_str(), - None, - vec![ResponseData::ProcStderr { id, data }], - ); - if tx_2.send(res).await.is_err() { + let payload = vec![ResponseData::ProcStderr { id, data }]; + if !reply_2(payload).await { error!(" Stderr channel closed", conn_id); break; } @@ -524,6 +526,7 @@ async fn proc_run( // kill the process when triggered let state_2 = Arc::clone(&state); let (kill_tx, kill_rx) = oneshot::channel(); + let mut reply_2 = reply.clone(); let wait_task = tokio::spawn(async move { tokio::select! { status = child.wait() => { @@ -547,12 +550,8 @@ async fn proc_run( Ok(status) => { let success = status.success(); let code = status.code(); - let res = Response::new( - tenant.as_str(), - None, - vec![ResponseData::ProcDone { id, success, code }] - ); - if tx.send(res).await.is_err() { + let payload = vec![ResponseData::ProcDone { id, success, code }]; + if !reply_2(payload).await { error!( " Failed to send done for process {}!", conn_id, @@ -561,8 +560,8 @@ async fn proc_run( } } Err(x) => { - let res = Response::new(tenant.as_str(), None, vec![ResponseData::from(x)]); - if tx.send(res).await.is_err() { + let payload = vec![ResponseData::from(x)]; + if !reply_2(payload).await { error!( " Failed to send error for waiting on process {}!", conn_id, @@ -594,10 +593,8 @@ async fn proc_run( state_2.lock().await.remove_process(conn_id, id); - let res = Response::new(tenant.as_str(), None, vec![ResponseData::ProcDone { - id, success: false, code: None - }]); - if tx.send(res).await.is_err() { + let payload = vec![ResponseData::ProcDone { id, success: false, code: None }]; + if !reply_2(payload).await { error!(" Failed to send done for process {}!", conn_id, id); } } diff --git a/core/src/server/relay.rs b/core/src/server/relay.rs index 528eae8..17ea684 100644 --- a/core/src/server/relay.rs +++ b/core/src/server/relay.rs @@ -1,16 +1,15 @@ use crate::{ - client::Session, - constants::CLIENT_BROADCAST_CHANNEL_CAPACITY, - data::{Request, RequestData, Response, ResponseData}, - net::{Codec, DataStream, Transport, TransportReadHalf, TransportWriteHalf}, + client::{Session, SessionChannel}, + data::{Request, RequestData, ResponseData}, + net::{Codec, DataStream, Transport}, server::utils::{ConnTracker, ShutdownTask}, }; use futures::stream::{Stream, StreamExt}; use log::*; use std::{collections::HashMap, marker::Unpin, sync::Arc}; use tokio::{ - io::{self, AsyncRead, AsyncWrite}, - sync::{mpsc, oneshot, Mutex}, + io, + sync::{oneshot, Mutex}, task::{JoinError, JoinHandle}, time::Duration, }; @@ -19,106 +18,33 @@ use tokio::{ /// actual server pub struct RelayServer { accept_task: JoinHandle<()>, - broadcast_task: JoinHandle<()>, - forward_task: JoinHandle<()>, conns: Arc>>, } impl RelayServer { - pub fn initialize( - mut session: Session, + pub fn initialize( + session: Session, mut stream: S, shutdown_after: Option, ) -> io::Result where - T1: DataStream + 'static, - T2: DataStream + Send + 'static, - U1: Codec + Send + 'static, - U2: Codec + Send + 'static, - S: Stream> + Send + Unpin + 'static, + T: DataStream + Send + 'static, + U: Codec + Send + 'static, + S: Stream> + Send + Unpin + 'static, { let conns: Arc>> = Arc::new(Mutex::new(HashMap::new())); - // Spawn task to send server responses to the appropriate connections - let conns_2 = Arc::clone(&conns); - debug!("Spawning response broadcast task"); - let mut broadcast = session.broadcast.take().unwrap(); - let (shutdown_broadcast_tx, mut shutdown_broadcast_rx) = mpsc::channel::<()>(1); - let broadcast_task = tokio::spawn(async move { - loop { - let res = tokio::select! { - maybe_res = broadcast.recv() => { - match maybe_res { - Some(res) => res, - None => break, - } - } - _ = shutdown_broadcast_rx.recv() => { - break; - } - }; - - // Search for all connections with a tenant that matches the response's tenant - for conn in conns_2.lock().await.values_mut() { - if conn.state.lock().await.tenant.as_deref() == Some(res.tenant.as_str()) { - debug!( - "Forwarding response of type{} {} to connection {}", - if res.payload.len() > 1 { "s" } else { "" }, - res.to_payload_type_string(), - conn.id - ); - if let Err(x) = conn.forward_response(res).await { - error!("Failed to pass forwarding message: {}", x); - } - - // NOTE: We assume that tenant is unique, so we can break after - // forwarding the message to the first matching tenant - break; - } - } - } - }); - - // Spawn task to send to the server requests from connections - debug!("Spawning request forwarding task"); - let (req_tx, mut req_rx) = mpsc::channel::(CLIENT_BROADCAST_CHANNEL_CAPACITY); - let (shutdown_forward_tx, mut shutdown_forward_rx) = mpsc::channel::<()>(1); - let forward_task = tokio::spawn(async move { - loop { - let req = tokio::select! { - maybe_req = req_rx.recv() => { - match maybe_req { - Some(req) => req, - None => break, - } - } - _ = shutdown_forward_rx.recv() => { - break; - } - }; - - debug!( - "Forwarding request of type{} {} to server", - if req.payload.len() > 1 { "s" } else { "" }, - req.to_payload_type_string() - ); - if let Err(x) = session.fire(req).await { - error!("Session failed to send request: {:?}", x); - break; - } - } - }); - let (shutdown, tracker) = ShutdownTask::maybe_initialize(shutdown_after); let conns_2 = Arc::clone(&conns); let accept_task = tokio::spawn(async move { let inner = async move { loop { + let channel = session.clone_channel(); match stream.next().await { Some(transport) => { let result = Conn::initialize( transport, - req_tx.clone(), + channel, tracker.as_ref().map(Arc::clone), ) .await; @@ -149,34 +75,19 @@ impl RelayServer { }, None => inner.await, } - - // Doesn't matter if we send or drop these as long as they persist until this - // task is completed, so just drop - drop(shutdown_broadcast_tx); - drop(shutdown_forward_tx); }); - Ok(Self { - accept_task, - broadcast_task, - forward_task, - conns, - }) + Ok(Self { accept_task, conns }) } /// Waits for the server to terminate pub async fn wait(self) -> Result<(), JoinError> { - match tokio::try_join!(self.accept_task, self.broadcast_task, self.forward_task) { - Ok(_) => Ok(()), - Err(x) => Err(x), - } + self.accept_task.await } /// Aborts the server by aborting the internal tasks and current connections pub async fn abort(&self) { self.accept_task.abort(); - self.broadcast_task.abort(); - self.forward_task.abort(); self.conns .lock() .await @@ -187,24 +98,13 @@ impl RelayServer { struct Conn { id: usize, - req_task: JoinHandle<()>, - res_task: JoinHandle<()>, - _cleanup_task: JoinHandle<()>, - res_tx: mpsc::Sender, - state: Arc>, -} - -/// Represents state associated with a connection -#[derive(Default)] -struct ConnState { - tenant: Option, - processes: Vec, + conn_task: JoinHandle<()>, } impl Conn { pub async fn initialize( transport: Transport, - req_tx: mpsc::Sender, + channel: SessionChannel, ct: Option>>, ) -> io::Result where @@ -215,59 +115,14 @@ impl Conn { // is not guaranteed to have an identifiable string let id: usize = rand::random(); - let (t_read, t_write) = transport.into_split(); - - // Used to alert our response task of the connection's tenant name - // based on the first - let (tenant_tx, tenant_rx) = oneshot::channel(); - - // Create a state we use to keep track of connection-specific data - debug!(" Initializing internal state", id); - let state = Arc::new(Mutex::new(ConnState::default())); - // Mark that we have a new connection if let Some(ct) = ct.as_ref() { ct.lock().await.increment(); } - // Spawn task to continually receive responses from the session that - // may or may not be relevant to the connection, which will filter - // by tenant and then along any response that matches - let (res_tx, res_rx) = mpsc::channel::(CLIENT_BROADCAST_CHANNEL_CAPACITY); - let (res_task_tx, res_task_rx) = oneshot::channel(); - let state_2 = Arc::clone(&state); - let res_task = tokio::spawn(async move { - handle_conn_outgoing(id, state_2, t_write, tenant_rx, res_rx).await; - let _ = res_task_tx.send(()); - }); - - // Spawn task to continually read requests from connection and forward - // them along to be sent via the session - let req_tx = req_tx.clone(); - let (req_task_tx, req_task_rx) = oneshot::channel(); - let state_2 = Arc::clone(&state); - let req_task = tokio::spawn(async move { - handle_conn_incoming(id, state_2, t_read, tenant_tx, req_tx).await; - let _ = req_task_tx.send(()); - }); + let conn_task = spawn_conn_handler(id, transport, channel, ct).await; - let _cleanup_task = tokio::spawn(async move { - let _ = tokio::join!(req_task_rx, res_task_rx); - - if let Some(ct) = ct.as_ref() { - ct.lock().await.decrement(); - } - debug!(" Disconnected", id); - }); - - Ok(Self { - id, - req_task, - res_task, - _cleanup_task, - res_tx, - state, - }) + Ok(Self { id, conn_task }) } /// Id associated with the connection @@ -277,153 +132,125 @@ impl Conn { /// Aborts the connection from the server side pub fn abort(&self) { - // NOTE: We don't abort the cleanup task as that needs to actually happen - // and will even if these tasks are aborted - self.req_task.abort(); - self.res_task.abort(); - } - - /// Forwards a response back through this connection - pub async fn forward_response( - &mut self, - res: Response, - ) -> Result<(), mpsc::error::SendError> { - self.res_tx.send(res).await + self.conn_task.abort(); } } -/// Conn::Request -> Session::Fire -async fn handle_conn_incoming( +async fn spawn_conn_handler( conn_id: usize, - state: Arc>, - mut reader: TransportReadHalf, - tenant_tx: oneshot::Sender, - req_tx: mpsc::Sender, -) where - T: AsyncRead + Unpin, - U: Codec, + transport: Transport, + mut channel: SessionChannel, + ct: Option>>, +) -> JoinHandle<()> +where + T: DataStream, + U: Codec + Send + 'static, { - macro_rules! process_req { - ($on_success:expr; $done:expr) => { - match reader.receive::().await { - Ok(Some(req)) => { - $on_success(&req); - if let Err(x) = req_tx.send(req).await { - error!( - "Failed to pass along request received on unix socket: {:?}", - x - ); - $done; - } - } - Ok(None) => $done, - Err(x) => { - error!("Failed to receive request from unix stream: {:?}", x); - $done; - } + let (mut t_reader, t_writer) = transport.into_split(); + let processes = Arc::new(Mutex::new(Vec::new())); + let t_writer = Arc::new(Mutex::new(t_writer)); + + let (done_tx, done_rx) = oneshot::channel(); + let mut channel_2 = channel.clone(); + let processes_2 = Arc::clone(&processes); + let task = tokio::spawn(async move { + loop { + if channel_2.is_closed() { + break; } - }; - } - let mut tenant = None; + // For each request, forward it through the session and monitor all responses + match t_reader.receive::().await { + Ok(Some(req)) => match channel_2.mail(req).await { + Ok(mut mailbox) => { + let processes = Arc::clone(&processes_2); + let t_writer = Arc::clone(&t_writer); + tokio::spawn(async move { + while let Some(res) = mailbox.next().await { + // Keep track of processes that are started so we can kill them + // when we're done + { + let mut p_lock = processes.lock().await; + for data in res.payload.iter() { + if let ResponseData::ProcStart { id } = *data { + p_lock.push(id); + } + } + } - // NOTE: Have to acquire our first request outside our loop since the oneshot - // sender of the tenant's name is consuming - process_req!( - |req: &Request| { - tenant = Some(req.tenant.clone()); - if let Err(x) = tenant_tx.send(req.tenant.clone()) { - error!("Failed to send along acquired tenant name: {:?}", x); - return; + if let Err(x) = t_writer.lock().await.send(res).await { + error!( + " Failed to send response back: {}", + conn_id, x + ); + } + } + }); + } + Err(x) => error!( + " Failed to pass along request received on unix socket: {:?}", + conn_id, x + ), + }, + Ok(None) => break, + Err(x) => { + error!( + " Failed to receive request from unix stream: {:?}", + conn_id, x + ); + break; + } } - }; - return - ); - - // Loop and process all additional requests - loop { - process_req!(|_| {}; break); - } + } - // At this point, we have processed at least one request successfully - // and should have the tenant populated. If we had a failure at the - // beginning, we exit the function early via return. - let tenant = tenant.unwrap(); + let _ = done_tx.send(()); + }); // Perform cleanup if done by sending a request to kill each running process - // debug!("Cleaning conn {} :: killing process {}", conn_id, id); - if let Err(x) = req_tx - .send(Request::new( - tenant.clone(), - state - .lock() - .await - .processes - .iter() - .map(|id| RequestData::ProcKill { id: *id }) - .collect(), - )) - .await - { - error!(" Failed to send kill signals: {}", conn_id, x); - } -} + tokio::spawn(async move { + let _ = done_rx.await; -async fn handle_conn_outgoing( - conn_id: usize, - state: Arc>, - mut writer: TransportWriteHalf, - tenant_rx: oneshot::Receiver, - mut res_rx: mpsc::Receiver, -) where - T: AsyncWrite + Unpin, - U: Codec, -{ - // We wait for the tenant to be identified by the first request - // before processing responses to be sent back; this is easier - // to implement and yields the same result as we would be dropping - // all responses before we know the tenant - if let Ok(tenant) = tenant_rx.await { - debug!("Associated tenant {} with conn {}", tenant, conn_id); - state.lock().await.tenant = Some(tenant.clone()); - - while let Some(res) = res_rx.recv().await { - debug!( - "Conn {} being sent response of type{} {}", + let p_lock = processes.lock().await; + if !p_lock.is_empty() { + trace!( + "Cleaning conn {} :: killing {} process", conn_id, - if res.payload.len() > 1 { "s" } else { "" }, - res.to_payload_type_string(), + p_lock.len() ); - - // If a new process was started, we want to capture the id and - // associate it with the connection - let ids = res.payload.iter().filter_map(|x| match x { - ResponseData::ProcStart { id } => Some(*id), - _ => None, - }); - for id in ids { - debug!("Tracking proc {} for conn {}", id, conn_id); - state.lock().await.processes.push(id); + if let Err(x) = channel + .fire(Request::new( + "relay", + p_lock + .iter() + .map(|id| RequestData::ProcKill { id: *id }) + .collect(), + )) + .await + { + error!(" Failed to send kill signals: {}", conn_id, x); } + } - if let Err(x) = writer.send(res).await { - error!("Failed to send response through unix connection: {}", x); - break; - } + if let Some(ct) = ct.as_ref() { + ct.lock().await.decrement(); } - } + debug!(" Disconnected", conn_id); + }); + + task } #[cfg(test)] mod tests { use super::*; - use crate::net::{InmemoryStream, PlainCodec}; + use crate::{ + data::Response, + net::{InmemoryStream, PlainCodec}, + }; use std::{pin::Pin, time::Duration}; + use tokio::sync::mpsc; - fn make_session() -> ( - Transport, - Session, - ) { + fn make_session() -> (Transport, Session) { let (t1, t2) = Transport::make_pair(); (t1, Session::initialize(t2).unwrap()) } @@ -519,11 +346,16 @@ mod tests { // Clear out the transport channel (outbound of session) // NOTE: Because our test stream uses a buffer size of 1, we have to clear out the // outbound data from the earlier requests before we can send back a response - let _ = transport.receive::().await.unwrap().unwrap(); - let _ = transport.receive::().await.unwrap().unwrap(); + let req_1 = transport.receive::().await.unwrap().unwrap(); + let req_2 = transport.receive::().await.unwrap().unwrap(); + let origin_id = if req_1.tenant == "test-tenant-2" { + req_1.id + } else { + req_2.id + }; // Send a response back to a singular connection based on the tenant - let res = Response::new("test-tenant-2", None, vec![ResponseData::Ok]); + let res = Response::new("test-tenant-2", origin_id, vec![ResponseData::Ok]); transport.send(res.clone()).await.unwrap(); // Verify that response is only received by a singular connection diff --git a/src/exit.rs b/src/exit.rs index 5196d28..78feb99 100644 --- a/src/exit.rs +++ b/src/exit.rs @@ -106,12 +106,11 @@ impl ExitCodeError for TransportError { impl ExitCodeError for RemoteProcessError { fn is_silent(&self) -> bool { - matches!(self, Self::BadResponse) + true } fn to_exit_code(&self) -> ExitCode { match self { - Self::BadResponse => ExitCode::DataErr, Self::ChannelDead => ExitCode::Unavailable, Self::TransportError(x) => x.to_exit_code(), Self::UnexpectedEof => ExitCode::IoError, diff --git a/src/lib.rs b/src/lib.rs index a45bd21..a546d13 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -38,15 +38,7 @@ fn init_logging(opt: &opt::CommonOpt, is_remote_process: bool) -> flexi_logger:: // For each module, configure logging for module in modules { - builder.module( - module, - match opt.verbose { - 0 => LevelFilter::Warn, - 1 => LevelFilter::Info, - 2 => LevelFilter::Debug, - _ => LevelFilter::Trace, - }, - ); + builder.module(module, opt.log_level.to_log_level_filter()); // If quiet, we suppress all logging output // diff --git a/src/opt.rs b/src/opt.rs index e45c725..24f70cf 100644 --- a/src/opt.rs +++ b/src/opt.rs @@ -40,17 +40,58 @@ impl Opt { } } +#[derive( + Copy, + Clone, + Debug, + Display, + PartialEq, + Eq, + IsVariant, + IntoStaticStr, + EnumString, + EnumVariantNames, +)] +#[strum(serialize_all = "snake_case")] +pub enum LogLevel { + Off, + Error, + Warn, + Info, + Debug, + Trace, +} + +impl LogLevel { + pub fn to_log_level_filter(self) -> log::LevelFilter { + match self { + Self::Off => log::LevelFilter::Off, + Self::Error => log::LevelFilter::Error, + Self::Warn => log::LevelFilter::Warn, + Self::Info => log::LevelFilter::Info, + Self::Debug => log::LevelFilter::Debug, + Self::Trace => log::LevelFilter::Trace, + } + } +} + /// Contains options that are common across subcommands #[derive(Debug, StructOpt)] pub struct CommonOpt { - /// Verbose mode (-v, -vv, -vvv, etc.) - #[structopt(short, long, parse(from_occurrences), global = true)] - pub verbose: u8, - - /// Quiet mode, suppresses all logging + /// Quiet mode, suppresses all logging (shortcut for log level off) #[structopt(short, long, global = true)] pub quiet: bool, + /// Log level to use throughout the application + #[structopt( + long, + global = true, + case_insensitive = true, + default_value = LogLevel::Info.into(), + possible_values = LogLevel::VARIANTS + )] + pub log_level: LogLevel, + /// Log output to disk instead of stderr #[structopt(long, global = true)] pub log_file: Option, diff --git a/src/session.rs b/src/session.rs index 7a54cba..279f834 100644 --- a/src/session.rs +++ b/src/session.rs @@ -1,35 +1,38 @@ use crate::{ buf::StringBuf, constants::MAX_PIPE_CHUNK_SIZE, opt::Format, output::ResponseOut, stdin, }; -use distant_core::{Codec, DataStream, Request, RequestData, Response, Session}; +use distant_core::{Mailbox, Request, RequestData, Session}; use log::*; -use std::{io, thread}; +use std::io; use structopt::StructOpt; -use tokio::{sync::mpsc, task::JoinHandle}; +use tokio::{ + sync::{mpsc, oneshot}, + task::JoinHandle, +}; /// Represents a wrapper around a session that provides CLI functionality such as reading from /// stdin and piping results back out to stdout pub struct CliSession { - _stdin_thread: thread::JoinHandle<()>, req_task: JoinHandle<()>, - res_task: JoinHandle>, } impl CliSession { - pub fn new(tenant: String, mut session: Session, format: Format) -> Self - where - T: DataStream + 'static, - U: Codec + Send + 'static, - { - let (stdin_thread, stdin_rx) = stdin::spawn_channel(MAX_PIPE_CHUNK_SIZE); + /// Creates a new instance of a session for use in CLI interactions being fed input using + /// the program's stdin + pub fn new_for_stdin(tenant: String, session: Session, format: Format) -> Self { + let (_stdin_thread, stdin_rx) = stdin::spawn_channel(MAX_PIPE_CHUNK_SIZE); - let (exit_tx, exit_rx) = mpsc::channel(1); - let broadcast = session.broadcast.take().unwrap(); - let res_task = - tokio::spawn( - async move { process_incoming_responses(broadcast, format, exit_rx).await }, - ); + Self::new(tenant, session, format, stdin_rx) + } + /// Creates a new instance of a session for use in CLI interactions being fed input using + /// the provided receiver + pub fn new( + tenant: String, + session: Session, + format: Format, + stdin_rx: mpsc::Receiver, + ) -> Self { let map_line = move |line: &str| match format { Format::Json => serde_json::from_str(line) .map_err(|x| io::Error::new(io::ErrorKind::InvalidInput, x)), @@ -44,61 +47,53 @@ impl CliSession { } }; let req_task = tokio::spawn(async move { - process_outgoing_requests(session, stdin_rx, exit_tx, format, map_line).await + process_outgoing_requests(session, stdin_rx, format, map_line).await }); - Self { - _stdin_thread: stdin_thread, - req_task, - res_task, - } + Self { req_task } } /// Wait for the cli session to terminate pub async fn wait(self) -> io::Result<()> { - match tokio::try_join!(self.req_task, self.res_task) { - Ok((_, res)) => res, + match self.req_task.await { + Ok(res) => Ok(res), Err(x) => Err(io::Error::new(io::ErrorKind::BrokenPipe, x)), } } } -/// Helper function that loops, processing incoming responses not tied to a request to be sent out -/// over stdout/stderr -async fn process_incoming_responses( - mut broadcast: mpsc::Receiver, - format: Format, - mut exit: mpsc::Receiver<()>, -) -> io::Result<()> { - loop { - tokio::select! { - res = broadcast.recv() => { - match res { - Some(res) => ResponseOut::new(format, res)?.print(), - None => return Ok(()), +/// Helper function that loops, processing incoming responses to a mailbox +async fn process_mailbox(mut mailbox: Mailbox, format: Format, exit: oneshot::Receiver<()>) { + let inner = async move { + while let Some(res) = mailbox.next().await { + match ResponseOut::new(format, res) { + Ok(out) => out.print(), + Err(x) => { + error!("{}", x); + break; } } - _ = exit.recv() => { - return Ok(()); - } } + }; + + tokio::select! { + _ = inner => {} + _ = exit => {} } } /// Helper function that loops, processing outgoing requests created from stdin, and printing out /// responses -async fn process_outgoing_requests( - mut session: Session, +async fn process_outgoing_requests( + mut session: Session, mut stdin_rx: mpsc::Receiver, - exit_tx: mpsc::Sender<()>, format: Format, map_line: F, ) where - T: DataStream, - U: Codec, F: Fn(&str) -> io::Result, { let mut buf = StringBuf::new(); + let mut mailbox_exits = Vec::new(); while let Some(data) = stdin_rx.recv().await { // Update our buffer with the new data and split it into concrete lines and remainder @@ -112,21 +107,31 @@ async fn process_outgoing_requests( trace!("Processing line: {:?}", line); if line.is_empty() { continue; - } else if line == "exit" { - debug!("Got exit request, so closing cli session"); - stdin_rx.close(); - if exit_tx.send(()).await.is_err() { - error!("Failed to close cli session"); - } - continue; } match map_line(line) { - Ok(req) => match session.send(req).await { - Ok(res) => match ResponseOut::new(format, res) { - Ok(out) => out.print(), - Err(x) => error!("Failed to format response: {}", x), - }, + Ok(req) => match session.mail(req).await { + Ok(mut mailbox) => { + // Wait to get our first response before moving on to the next line + // of input + if let Some(res) = mailbox.next().await { + // Convert to response to output, and when successful launch + // a handler for continued responses to the same request + // such as with processes + match ResponseOut::new(format, res) { + Ok(out) => { + out.print(); + + let (tx, rx) = oneshot::channel(); + mailbox_exits.push(tx); + tokio::spawn(process_mailbox(mailbox, format, rx)); + } + Err(x) => { + error!("{}", x); + } + } + } + } Err(x) => { error!("Failed to send request: {}", x) } @@ -138,4 +143,9 @@ async fn process_outgoing_requests( } } } + + // Close out any dangling mailbox handlers + for tx in mailbox_exits { + let _ = tx.send(()); + } } diff --git a/src/subcommand/action.rs b/src/subcommand/action.rs index 269dabd..ef826d8 100644 --- a/src/subcommand/action.rs +++ b/src/subcommand/action.rs @@ -8,8 +8,7 @@ use crate::{ }; use derive_more::{Display, Error, From}; use distant_core::{ - Codec, DataStream, LspData, RemoteProcess, RemoteProcessError, Request, RequestData, Session, - TransportError, + LspData, RemoteProcess, RemoteProcessError, Request, RequestData, Session, TransportError, }; use tokio::{io, time::Duration}; @@ -66,23 +65,20 @@ async fn run_async(cmd: ActionSubcommand, opt: CommonOpt) -> Result<(), Error> { ) } -async fn start( +async fn start( cmd: ActionSubcommand, - mut session: Session, + mut session: Session, timeout: Duration, lsp_data: Option, -) -> Result<(), Error> -where - T: DataStream + 'static, - U: Codec + Send + 'static, -{ +) -> Result<(), Error> { let is_shell_format = matches!(cmd.format, Format::Shell); match (cmd.interactive, cmd.operation) { // ProcRun request w/ shell format is specially handled and we ignore interactive as // the stdin will be used for sending ProcStdin to remote process (_, Some(RequestData::ProcRun { cmd, args })) if is_shell_format => { - let mut proc = RemoteProcess::spawn(utils::new_tenant(), session, cmd, args).await?; + let mut proc = + RemoteProcess::spawn(utils::new_tenant(), &mut session, cmd, args).await?; // If we also parsed an LSP's initialize request for its session, we want to forward // it along in the case of a process call @@ -97,6 +93,12 @@ where proc.stderr.take().unwrap(), ); + // Drop main session as the singular remote process will now manage stdin/stdout/stderr + // NOTE: Without this, severing stdin when from this side would not occur as we would + // continue to maintain a second reference to the remote connection's input + // through the primary session + drop(session); + let (success, exit_code) = proc.wait().await?; // Shut down our link @@ -145,7 +147,7 @@ where // Enter into CLI session where we receive requests on stdin and send out // over stdout/stderr - let cli_session = CliSession::new(utils::new_tenant(), session, cmd.format); + let cli_session = CliSession::new_for_stdin(utils::new_tenant(), session, cmd.format); cli_session.wait().await?; Ok(()) diff --git a/src/subcommand/launch.rs b/src/subcommand/launch.rs index c70bd84..396b474 100644 --- a/src/subcommand/launch.rs +++ b/src/subcommand/launch.rs @@ -129,7 +129,7 @@ async fn keep_loop(info: SessionInfo, format: Format, duration: Duration) -> io: let codec = XChaCha20Poly1305Codec::from(info.key); match Session::tcp_connect_timeout(addr, codec, duration).await { Ok(session) => { - let cli_session = CliSession::new(utils::new_tenant(), session, format); + let cli_session = CliSession::new_for_stdin(utils::new_tenant(), session, format); cli_session.wait().await } Err(x) => Err(x), diff --git a/src/subcommand/lsp.rs b/src/subcommand/lsp.rs index 987f3d6..25b0426 100644 --- a/src/subcommand/lsp.rs +++ b/src/subcommand/lsp.rs @@ -5,7 +5,7 @@ use crate::{ utils, }; use derive_more::{Display, Error, From}; -use distant_core::{Codec, DataStream, LspData, RemoteLspProcess, RemoteProcessError, Session}; +use distant_core::{LspData, RemoteLspProcess, RemoteProcessError, Session}; use tokio::io; #[derive(Debug, Display, Error, From)] @@ -53,16 +53,13 @@ async fn run_async(cmd: LspSubcommand, opt: CommonOpt) -> Result<(), Error> { ) } -async fn start( +async fn start( cmd: LspSubcommand, - session: Session, + mut session: Session, lsp_data: Option, -) -> Result<(), Error> -where - T: DataStream + 'static, - U: Codec + Send + 'static, -{ - let mut proc = RemoteLspProcess::spawn(utils::new_tenant(), session, cmd.cmd, cmd.args).await?; +) -> Result<(), Error> { + let mut proc = + RemoteLspProcess::spawn(utils::new_tenant(), &mut session, cmd.cmd, cmd.args).await?; // If we also parsed an LSP's initialize request for its session, we want to forward // it along in the case of a process call diff --git a/tests/cli/action/proc_run.rs b/tests/cli/action/proc_run.rs index f57f4aa..e9a6be2 100644 --- a/tests/cli/action/proc_run.rs +++ b/tests/cli/action/proc_run.rs @@ -327,7 +327,6 @@ fn should_support_json_to_forward_stdin_to_remote_process(ctx: &'_ DistantServer let mut child = distant_subcommand(ctx, "action") .args(&["--format", "json"]) .arg("--interactive") - .args(&["--log-file", "/tmp/test.log", "-vvv"]) .spawn() .unwrap();