diff --git a/distant-core/src/net/transport/inmemory.rs b/distant-core/src/net/transport/inmemory.rs index 9288c16..039a56f 100644 --- a/distant-core/src/net/transport/inmemory.rs +++ b/distant-core/src/net/transport/inmemory.rs @@ -5,7 +5,7 @@ use std::{ }; use tokio::{ io::{self, AsyncRead, AsyncWrite, ReadBuf}, - sync::mpsc, + sync::mpsc::{self, error::TrySendError}, }; /// Represents a data stream comprised of two inmemory channels @@ -147,7 +147,8 @@ impl AsyncWrite for InmemoryStreamWriteHalf { ) -> Poll> { match self.0.try_send(buf.to_vec()) { Ok(_) => Poll::Ready(Ok(buf.len())), - Err(_) => Poll::Ready(Ok(0)), + Err(TrySendError::Full(_)) => Poll::Pending, + Err(TrySendError::Closed(_)) => Poll::Ready(Ok(0)), } } @@ -309,4 +310,233 @@ mod tests { assert_eq!(rx.recv().await, None, "Unexpectedly got more data"); } + + #[tokio::test] + async fn read_half_should_fail_if_buf_has_no_space_remaining() { + let (_tx, _rx, stream) = InmemoryStream::make(1); + let (mut t_read, _t_write) = stream.into_split(); + + let mut buf = [0u8; 0]; + match t_read.read(&mut buf).await { + Err(x) if x.kind() == io::ErrorKind::Other => {} + x => panic!("Unexpected result: {:?}", x), + } + } + + #[tokio::test] + async fn read_half_should_update_buf_with_all_overflow_from_last_read_if_it_all_fits() { + let (tx, _rx, stream) = InmemoryStream::make(1); + let (mut t_read, _t_write) = stream.into_split(); + + tx.send(vec![1, 2, 3]).await.expect("Failed to send"); + + let mut buf = [0u8; 2]; + + // First, read part of the data (first two bytes) + match t_read.read(&mut buf).await { + Ok(n) if n == 2 => assert_eq!(&buf[..n], &[1, 2]), + x => panic!("Unexpected result: {:?}", x), + } + + // Second, we send more data because the last message was placed in overflow + tx.send(vec![4, 5, 6]).await.expect("Failed to send"); + + // Third, read remainder of the overflow from first message (third byte) + match t_read.read(&mut buf).await { + Ok(n) if n == 1 => assert_eq!(&buf[..n], &[3]), + x => panic!("Unexpected result: {:?}", x), + } + + // Fourth, verify that we start to receive the next overflow + match t_read.read(&mut buf).await { + Ok(n) if n == 2 => assert_eq!(&buf[..n], &[4, 5]), + x => panic!("Unexpected result: {:?}", x), + } + + // Fifth, verify that we get the last bit of overflow + match t_read.read(&mut buf).await { + Ok(n) if n == 1 => assert_eq!(&buf[..n], &[6]), + x => panic!("Unexpected result: {:?}", x), + } + } + + #[tokio::test] + async fn read_half_should_update_buf_with_some_of_overflow_that_can_fit() { + let (tx, _rx, stream) = InmemoryStream::make(1); + let (mut t_read, _t_write) = stream.into_split(); + + tx.send(vec![1, 2, 3, 4, 5]).await.expect("Failed to send"); + + let mut buf = [0u8; 2]; + + // First, read part of the data (first two bytes) + match t_read.read(&mut buf).await { + Ok(n) if n == 2 => assert_eq!(&buf[..n], &[1, 2]), + x => panic!("Unexpected result: {:?}", x), + } + + // Second, we send more data because the last message was placed in overflow + tx.send(vec![6]).await.expect("Failed to send"); + + // Third, read next chunk of the overflow from first message (next two byte) + match t_read.read(&mut buf).await { + Ok(n) if n == 2 => assert_eq!(&buf[..n], &[3, 4]), + x => panic!("Unexpected result: {:?}", x), + } + + // Fourth, read last chunk of the overflow from first message (fifth byte) + match t_read.read(&mut buf).await { + Ok(n) if n == 1 => assert_eq!(&buf[..n], &[5]), + x => panic!("Unexpected result: {:?}", x), + } + } + + #[tokio::test] + async fn read_half_should_update_buf_with_all_of_inner_channel_when_it_fits() { + let (tx, _rx, stream) = InmemoryStream::make(1); + let (mut t_read, _t_write) = stream.into_split(); + + let mut buf = [0u8; 5]; + + tx.send(vec![1, 2, 3, 4, 5]).await.expect("Failed to send"); + + // First, read all of data that fits exactly + match t_read.read(&mut buf).await { + Ok(n) if n == 5 => assert_eq!(&buf[..n], &[1, 2, 3, 4, 5]), + x => panic!("Unexpected result: {:?}", x), + } + + tx.send(vec![6, 7, 8]).await.expect("Failed to send"); + + // Second, read data that fits within buf + match t_read.read(&mut buf).await { + Ok(n) if n == 3 => assert_eq!(&buf[..n], &[6, 7, 8]), + x => panic!("Unexpected result: {:?}", x), + } + } + + #[tokio::test] + async fn read_half_should_update_buf_with_some_of_inner_channel_that_can_fit_and_add_rest_to_overflow( + ) { + let (tx, _rx, stream) = InmemoryStream::make(1); + let (mut t_read, _t_write) = stream.into_split(); + + let mut buf = [0u8; 1]; + + tx.send(vec![1, 2, 3, 4, 5]).await.expect("Failed to send"); + + // Attempt a read that places more in overflow + match t_read.read(&mut buf).await { + Ok(n) if n == 1 => assert_eq!(&buf[..n], &[1]), + x => panic!("Unexpected result: {:?}", x), + } + + // Verify overflow contains the rest + assert_eq!(&t_read.overflow, &[2, 3, 4, 5]); + + // Queue up extra data that will not be read until overflow is finished + tx.send(vec![6, 7, 8]).await.expect("Failed to send"); + + // Read next data point + match t_read.read(&mut buf).await { + Ok(n) if n == 1 => assert_eq!(&buf[..n], &[2]), + x => panic!("Unexpected result: {:?}", x), + } + + // Verify overflow contains the rest without having added extra data + assert_eq!(&t_read.overflow, &[3, 4, 5]); + } + + #[tokio::test] + async fn read_half_should_yield_pending_if_no_data_available_on_inner_channel() { + let (_tx, _rx, stream) = InmemoryStream::make(1); + let (mut t_read, _t_write) = stream.into_split(); + + let mut buf = [0u8; 1]; + + // Attempt a read that should yield ok with no change, which is what should + // happen when nothing is read into buf + let f = t_read.read(&mut buf); + tokio::pin!(f); + match futures::poll!(f) { + Poll::Pending => {} + x => panic!("Unexpected poll result: {:?}", x), + } + } + + #[tokio::test] + async fn read_half_should_not_update_buf_if_inner_channel_closed() { + let (tx, _rx, stream) = InmemoryStream::make(1); + let (mut t_read, _t_write) = stream.into_split(); + + let mut buf = [0u8; 1]; + + // Drop the channel that would be sending data to the transport + drop(tx); + + // Attempt a read that should yield ok with no change, which is what should + // happen when nothing is read into buf + match t_read.read(&mut buf).await { + Ok(n) if n == 0 => assert_eq!(&buf, &[0]), + x => panic!("Unexpected result: {:?}", x), + } + } + + #[tokio::test] + async fn write_half_should_return_buf_len_if_can_send_immediately() { + let (_tx, mut rx, stream) = InmemoryStream::make(1); + let (_t_read, mut t_write) = stream.into_split(); + + // Write that is not waiting should always succeed with full contents + let n = t_write.write(&[1, 2, 3]).await.expect("Failed to write"); + assert_eq!(n, 3, "Unexpected byte count returned"); + + // Verify we actually had the data sent + let data = rx.try_recv().expect("Failed to recv data"); + assert_eq!(data, &[1, 2, 3]); + } + + #[tokio::test] + async fn write_half_should_return_support_eventually_sending_by_retrying_when_not_ready() { + let (_tx, mut rx, stream) = InmemoryStream::make(1); + let (_t_read, mut t_write) = stream.into_split(); + + // Queue a write already so that we block on the next one + t_write.write(&[1, 2, 3]).await.expect("Failed to write"); + + // Verify that the next write is pending + let f = t_write.write(&[4, 5]); + tokio::pin!(f); + match futures::poll!(&mut f) { + Poll::Pending => {} + x => panic!("Unexpected poll result: {:?}", x), + } + + // Consume first batch of data so future of second can continue + let data = rx.try_recv().expect("Failed to recv data"); + assert_eq!(data, &[1, 2, 3]); + + // Verify that poll now returns success + match futures::poll!(f) { + Poll::Ready(Ok(n)) if n == 2 => {} + x => panic!("Unexpected poll result: {:?}", x), + } + + // Consume second batch of data + let data = rx.try_recv().expect("Failed to recv data"); + assert_eq!(data, &[4, 5]); + } + + #[tokio::test] + async fn write_half_should_zero_if_inner_channel_closed() { + let (_tx, rx, stream) = InmemoryStream::make(1); + let (_t_read, mut t_write) = stream.into_split(); + + // Drop receiving end that transport would talk to + drop(rx); + + // Channel is dropped, so return 0 to indicate no bytes sent + let n = t_write.write(&[1, 2, 3]).await.expect("Failed to write"); + assert_eq!(n, 0, "Unexpected byte count returned"); + } } diff --git a/distant-core/src/server/distant/handler.rs b/distant-core/src/server/distant/handler.rs index 031104c..8a7f74b 100644 --- a/distant-core/src/server/distant/handler.rs +++ b/distant-core/src/server/distant/handler.rs @@ -103,8 +103,8 @@ pub(super) async fn process( args, detached, } => proc_run(conn_id, state, reply, cmd, args, detached).await, - RequestData::ProcKill { id } => proc_kill(state, id).await, - RequestData::ProcStdin { id, data } => proc_stdin(state, id, data).await, + RequestData::ProcKill { id } => proc_kill(conn_id, state, id).await, + RequestData::ProcStdin { id, data } => proc_stdin(conn_id, state, id, data).await, RequestData::ProcList {} => proc_list(state).await, RequestData::SystemInfo {} => system_info().await, } @@ -458,7 +458,7 @@ where Ok(data) => { let payload = vec![ResponseData::ProcStdout { id, data }]; if !reply_2(payload).await { - error!(" Stdout channel closed", conn_id); + error!(" Stdout channel closed", conn_id, id); break; } @@ -470,12 +470,21 @@ where .await; } Err(x) => { - error!("Invalid data read from stdout pipe: {}", x); + error!( + " Invalid data read from stdout pipe: {}", + conn_id, id, x + ); break; } }, Ok(_) => break, - Err(_) => break, + Err(x) => { + error!( + " Reading stdout failed: {}", + conn_id, id, x + ); + break; + } } } }); @@ -491,7 +500,7 @@ where Ok(data) => { let payload = vec![ResponseData::ProcStderr { id, data }]; if !reply_2(payload).await { - error!(" Stderr channel closed", conn_id); + error!(" Stderr channel closed", conn_id, id); break; } @@ -503,12 +512,21 @@ where .await; } Err(x) => { - error!("Invalid data read from stdout pipe: {}", x); + error!( + " Invalid data read from stdout pipe: {}", + conn_id, id, x + ); break; } }, Ok(_) => break, - Err(_) => break, + Err(x) => { + error!( + " Reading stderr failed: {}", + conn_id, id, x + ); + break; + } } } }); @@ -520,7 +538,7 @@ where while let Some(line) = stdin_rx.recv().await { if let Err(x) = stdin.write_all(line.as_bytes()).await { error!( - " Failed to send stdin to process {}: {}", + " Failed to send stdin: {}", conn_id, id, x ); break; @@ -536,18 +554,22 @@ where let wait_task = tokio::spawn(async move { tokio::select! { status = child.wait() => { - debug!(" Process {} done", conn_id, id); + debug!( + " Completed and waiting on stdout & stderr tasks", + conn_id, + id, + ); // Force stdin task to abort if it hasn't exited as there is no // point to sending any more stdin stdin_task.abort(); if let Err(x) = stderr_task.await { - error!(" Join on stderr task failed: {}", conn_id, x); + error!(" Join on stderr task failed: {}", conn_id, id, x); } if let Err(x) = stdout_task.await { - error!(" Join on stdout task failed: {}", conn_id, x); + error!(" Join on stdout task failed: {}", conn_id, id, x); } state_2.lock().await.remove_process(conn_id, id); @@ -559,7 +581,7 @@ where let payload = vec![ResponseData::ProcDone { id, success, code }]; if !reply_2(payload).await { error!( - " Failed to send done for process {}!", + " Failed to send done", conn_id, id, ); @@ -569,7 +591,7 @@ where let payload = vec![ResponseData::from(x)]; if !reply_2(payload).await { error!( - " Failed to send error for waiting on process {}!", + " Failed to send error for waiting", conn_id, id, ); @@ -579,10 +601,10 @@ where }, _ = kill_rx => { - debug!(" Process {} killed", conn_id, id); + debug!(" Killing", conn_id, id); if let Err(x) = child.kill().await { - error!(" Unable to kill process {}: {}", conn_id, id, x); + error!(" Unable to kill: {}", conn_id, id, x); } // Force stdin task to abort if it hasn't exited as there is no @@ -590,24 +612,24 @@ where stdin_task.abort(); if let Err(x) = stderr_task.await { - error!(" Join on stderr task failed: {}", conn_id, x); + error!(" Join on stderr task failed: {}", conn_id, id, x); } if let Err(x) = stdout_task.await { - error!(" Join on stdout task failed: {}", conn_id, x); + error!(" Join on stdout task failed: {}", conn_id, id, x); } // Wait for the child after being killed to ensure that it has been cleaned // up at the operating system level if let Err(x) = child.wait().await { - error!(" Failed to wait on killed process {}: {}", conn_id, id, x); + error!(" Failed to wait after killed: {}", conn_id, id, x); } state_2.lock().await.remove_process(conn_id, id); let payload = vec![ResponseData::ProcDone { id, success: false, code: None }]; if !reply_2(payload).await { - error!(" Failed to send done for process {}!", conn_id, id); + error!(" Failed to send done", conn_id, id); } } } @@ -625,7 +647,7 @@ where }) } -async fn proc_kill(state: HState, id: usize) -> Result { +async fn proc_kill(conn_id: usize, state: HState, id: usize) -> Result { if let Some(process) = state.lock().await.processes.remove(&id) { if process.kill() { return Ok(Outgoing::from(ResponseData::Ok)); @@ -634,11 +656,19 @@ async fn proc_kill(state: HState, id: usize) -> Result { Err(ServerError::IoError(io::Error::new( io::ErrorKind::BrokenPipe, - "Unable to send kill signal to process", + format!( + " Unable to send kill signal to process", + conn_id, id + ), ))) } -async fn proc_stdin(state: HState, id: usize, data: String) -> Result { +async fn proc_stdin( + conn_id: usize, + state: HState, + id: usize, + data: String, +) -> Result { if let Some(process) = state.lock().await.processes.get(&id) { if process.send_stdin(data).await { return Ok(Outgoing::from(ResponseData::Ok)); @@ -647,7 +677,10 @@ async fn proc_stdin(state: HState, id: usize, data: String) -> Result Unable to send stdin to process", + conn_id, id, + ), ))) } diff --git a/distant-ssh2/src/handler.rs b/distant-ssh2/src/handler.rs index f51e080..b42a9cb 100644 --- a/distant-ssh2/src/handler.rs +++ b/distant-ssh2/src/handler.rs @@ -672,7 +672,7 @@ where Ok(data) => { let payload = vec![ResponseData::ProcStdout { id, data }]; if !reply_2(payload).await { - error!(" Stdout channel closed", id); + error!(" Stdout channel closed", id); break; } @@ -685,7 +685,7 @@ where } Err(x) => { error!( - " Invalid data read from stdout pipe: {}", + " Invalid data read from stdout pipe: {}", id, x ); break; @@ -698,7 +698,10 @@ where tokio::time::sleep(tokio::time::Duration::from_millis(READ_PAUSE_MILLIS)) .await; } - Err(_) => break, + Err(x) => { + error!(" Stdout unexpectedly closed: {}", id, x); + break; + } } } }); @@ -713,7 +716,7 @@ where Ok(data) => { let payload = vec![ResponseData::ProcStderr { id, data }]; if !reply_2(payload).await { - error!(" Stderr channel closed", id); + error!(" Stderr channel closed", id); break; } @@ -726,7 +729,7 @@ where } Err(x) => { error!( - " Invalid data read from stderr pipe: {}", + " Invalid data read from stderr pipe: {}", id, x ); break; @@ -739,7 +742,10 @@ where tokio::time::sleep(tokio::time::Duration::from_millis(READ_PAUSE_MILLIS)) .await; } - Err(_) => break, + Err(x) => { + error!(" Stderr unexpectedly closed: {}", id, x); + break; + } } } }); @@ -747,7 +753,7 @@ where let stdin_task = tokio::spawn(async move { while let Some(line) = stdin_rx.recv().await { if let Err(x) = stdin.write_all(line.as_bytes()) { - error!(" Failed to send stdin: {}", id, x); + error!(" Failed to send stdin: {}", id, x); break; } } @@ -770,7 +776,7 @@ where success = status.success(); } Err(x) => { - error!(" Waiting on process failed: {}", id, x); + error!(" Waiting on process failed: {}", id, x); } } } @@ -781,10 +787,10 @@ where stdin_task.abort(); if should_kill { - debug!(" Process killed", id); + debug!(" Killing", id); if let Err(x) = child.kill() { - error!(" Unable to kill process: {}", id, x); + error!(" Unable to kill process: {}", id, x); } // NOTE: At the moment, child.kill does nothing for wezterm_ssh::SshChildProcess; @@ -801,15 +807,18 @@ where .await; } } else { - debug!(" Process done", id); + debug!( + " Completed and waiting on stdout & stderr tasks", + id + ); } if let Err(x) = stderr_task.await { - error!(" Join on stderr task failed: {}", id, x); + error!(" Join on stderr task failed: {}", id, x); } if let Err(x) = stdout_task.await { - error!(" Join on stdout task failed: {}", id, x); + error!(" Join on stdout task failed: {}", id, x); } state_2.lock().await.processes.remove(&id); @@ -821,7 +830,7 @@ where }]; if !reply_2(payload).await { - error!(" Failed to send done!", id,); + error!(" Failed to send done", id,); } }); }); @@ -845,7 +854,7 @@ async fn proc_kill( Err(io::Error::new( io::ErrorKind::BrokenPipe, - "Unable to send kill signal to process", + format!(" Unable to send kill signal to process", id), )) } @@ -863,7 +872,7 @@ async fn proc_stdin( Err(io::Error::new( io::ErrorKind::BrokenPipe, - "Unable to send stdin to process", + format!(" Unable to send stdin to process", id), )) } diff --git a/distant-ssh2/src/lib.rs b/distant-ssh2/src/lib.rs index 510fb48..01972fa 100644 --- a/distant-ssh2/src/lib.rs +++ b/distant-ssh2/src/lib.rs @@ -507,17 +507,20 @@ impl Ssh2Session { if let Err(x) = handler::process(wez_session.clone(), Arc::clone(&state), req, tx.clone()).await { - error!("{}", x); + error!("Ssh session receiver handler failed: {}", x); } } + debug!("Ssh receiver task is now closed"); }); tokio::spawn(async move { while let Some(res) = rx.recv().await { - if t_write.send(res).await.is_err() { + if let Err(x) = t_write.send(res).await { + error!("Ssh session sender failed: {}", x); break; } } + debug!("Ssh sender task is now closed"); }); Ok(session)