diff --git a/core/src/net/transport/mod.rs b/core/src/net/transport/mod.rs index 3ced834..921e5f4 100644 --- a/core/src/net/transport/mod.rs +++ b/core/src/net/transport/mod.rs @@ -175,7 +175,7 @@ where /// when communicating across the wire pub async fn from_handshake(stream: T, auth_key: Option>) -> io::Result { let connection_tag = stream.to_connection_tag(); - trace!("Beginning handshake for {}", connection_tag); + trace!("Beginning handshake with {}", connection_tag); // First, wrap the raw stream in our framed codec let mut conn = Framed::new(stream, DistantCodec); @@ -187,6 +187,7 @@ where let public_key = EncodedPoint::from(private_key.public_key()); // Fourth, share a random salt and the public key with the server as our first message + trace!("Handshake with {} sending public key", connection_tag); let salt = Salt::generate(SALT_LEN).map_err(|x| io::Error::new(io::ErrorKind::Other, x))?; let mut data = Vec::new(); data.extend_from_slice(salt.as_ref()); @@ -196,6 +197,10 @@ where .map_err(|x| io::Error::new(io::ErrorKind::Other, x))?; // Fifth, wait for a response that we will assume is the other side's salt & public key + trace!( + "Handshake with {} waiting for remote public key", + connection_tag + ); let data = conn.next().await.ok_or_else(|| { io::Error::new( io::ErrorKind::UnexpectedEof, @@ -221,6 +226,7 @@ where // Seventh, establish a shared secret that is NOT uniformly random, so we can't // directly use it as our encryption key (32 bytes in length) + trace!("Handshake with {} computing shared secret", connection_tag); let shared_secret = private_key.diffie_hellman(&other_public_key); // Eighth, convert our secret key into an orion password that we'll use to derive @@ -241,11 +247,12 @@ where .map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))?; // Tenth, derive a higher-entropy key from our shared secret + trace!("Handshake with {} deriving encryption key", connection_tag); let derived_key = kdf::derive_key(&password, &mixed_salt, 3, 1 << 16, 32) .map_err(|x| io::Error::new(io::ErrorKind::Other, x))?; let crypt_key = Arc::new(derived_key); - trace!("Finished handshake for {}", connection_tag); + trace!("Finished handshake with {}", connection_tag); Ok(Self { conn, diff --git a/src/subcommand/action.rs b/src/subcommand/action.rs index dfc80d4..f013546 100644 --- a/src/subcommand/action.rs +++ b/src/subcommand/action.rs @@ -1,7 +1,7 @@ use crate::{ exit::{ExitCode, ExitCodeError}, link::RemoteProcessLink, - opt::{ActionSubcommand, CommonOpt, SessionInput}, + opt::{ActionSubcommand, CommonOpt, Format, SessionInput}, output::ResponseOut, session::CliSession, utils, @@ -123,10 +123,12 @@ async fn start( where T: DataStream + 'static, { + let is_shell_format = matches!(cmd.format, Format::Shell); + match (cmd.interactive, cmd.operation) { - // ProcRun request is specially handled and we ignore interactive as + // ProcRun request w/ shell format is specially handled and we ignore interactive as // the stdin will be used for sending ProcStdin to remote process - (_, Some(RequestData::ProcRun { cmd, args })) => { + (_, Some(RequestData::ProcRun { cmd, args })) if is_shell_format => { let mut proc = RemoteProcess::spawn(utils::new_tenant(), session, cmd, args).await?; // If we also parsed an LSP's initialize request for its session, we want to forward diff --git a/tests/cli/action/proc_run.rs b/tests/cli/action/proc_run.rs index 02cf0cd..4c042d4 100644 --- a/tests/cli/action/proc_run.rs +++ b/tests/cli/action/proc_run.rs @@ -1,4 +1,7 @@ -use crate::cli::{fixtures::*, utils::random_tenant}; +use crate::cli::{ + fixtures::*, + utils::{distant_subcommand, friendly_recv_line, random_tenant, spawn_line_reader}, +}; use assert_cmd::Command; use assert_fs::prelude::*; use distant::ExitCode; @@ -7,6 +10,7 @@ use distant_core::{ Request, RequestData, Response, ResponseData, }; use rstest::*; +use std::{io::Write, time::Duration}; lazy_static::lazy_static! { static ref TEMP_SCRIPT_DIR: assert_fs::TempDir = assert_fs::TempDir::new().unwrap(); @@ -162,6 +166,249 @@ fn should_support_json_to_execute_program_and_return_exit_status(mut action_cmd: ); } +#[rstest] +fn should_support_json_to_capture_and_print_stdout(ctx: &'_ DistantServerCtx) { + let output = String::from("some output"); + let req = Request { + id: rand::random(), + tenant: random_tenant(), + payload: vec![RequestData::ProcRun { + cmd: SCRIPT_RUNNER.to_string(), + args: vec![ + ECHO_ARGS_TO_STDOUT_SH.to_str().unwrap().to_string(), + output.to_string(), + ], + }], + }; + + // distant action --format json --interactive + let mut child = distant_subcommand(ctx, "action") + .args(&["--format", "json"]) + .arg("--interactive") + .spawn() + .unwrap(); + + let mut stdin = child.stdin.take().unwrap(); + let stdout = spawn_line_reader(child.stdout.take().unwrap()); + let stderr = spawn_line_reader(child.stderr.take().unwrap()); + + // Send our request as json + let req_string = format!("{}\n", serde_json::to_string(&req).unwrap()); + stdin.write_all(req_string.as_bytes()).unwrap(); + stdin.flush().unwrap(); + + // Get the indicator of a process started (first line returned can take ~7 seconds due to the + // handshake cost) + let out = + friendly_recv_line(&stdout, Duration::from_secs(30)).expect("Failed to get proc start"); + let res: Response = serde_json::from_str(&out).unwrap(); + assert!( + matches!(res.payload[0], ResponseData::ProcStart { .. }), + "Unexpected response: {:?}", + res.payload[0] + ); + + // Get stdout from process and verify it + let out = + friendly_recv_line(&stdout, Duration::from_secs(1)).expect("Failed to get proc stdout"); + let res: Response = serde_json::from_str(&out).unwrap(); + match &res.payload[0] { + ResponseData::ProcStdout { data, .. } => assert_eq!(data, &output), + x => panic!("Unexpected response: {:?}", x), + }; + + // Get the indicator of a process completion + let out = friendly_recv_line(&stdout, Duration::from_secs(1)).expect("Failed to get proc done"); + let res: Response = serde_json::from_str(&out).unwrap(); + match &res.payload[0] { + ResponseData::ProcDone { success, .. } => { + assert!(success, "Process failed unexpectedly"); + } + x => panic!("Unexpected response: {:?}", x), + }; + + // Verify that we received nothing on stderr channel + assert!( + stderr.try_recv().is_err(), + "Unexpectedly got result on stderr channel" + ); +} + +#[rstest] +fn should_support_json_to_capture_and_print_stderr(ctx: &'_ DistantServerCtx) { + let output = String::from("some output"); + let req = Request { + id: rand::random(), + tenant: random_tenant(), + payload: vec![RequestData::ProcRun { + cmd: SCRIPT_RUNNER.to_string(), + args: vec![ + ECHO_ARGS_TO_STDERR_SH.to_str().unwrap().to_string(), + output.to_string(), + ], + }], + }; + + // distant action --format json --interactive + let mut child = distant_subcommand(ctx, "action") + .args(&["--format", "json"]) + .arg("--interactive") + .spawn() + .unwrap(); + + let mut stdin = child.stdin.take().unwrap(); + let stdout = spawn_line_reader(child.stdout.take().unwrap()); + let stderr = spawn_line_reader(child.stderr.take().unwrap()); + + // Send our request as json + let req_string = format!("{}\n", serde_json::to_string(&req).unwrap()); + stdin.write_all(req_string.as_bytes()).unwrap(); + stdin.flush().unwrap(); + + // Get the indicator of a process started (first line returned can take ~7 seconds due to the + // handshake cost) + let out = + friendly_recv_line(&stdout, Duration::from_secs(30)).expect("Failed to get proc start"); + let res: Response = serde_json::from_str(&out).unwrap(); + assert!( + matches!(res.payload[0], ResponseData::ProcStart { .. }), + "Unexpected response: {:?}", + res.payload[0] + ); + + // Get stderr from process and verify it + let out = + friendly_recv_line(&stdout, Duration::from_secs(1)).expect("Failed to get proc stderr"); + let res: Response = serde_json::from_str(&out).unwrap(); + match &res.payload[0] { + ResponseData::ProcStderr { data, .. } => assert_eq!(data, &output), + x => panic!("Unexpected response: {:?}", x), + }; + + // Get the indicator of a process completion + let out = friendly_recv_line(&stdout, Duration::from_secs(1)).expect("Failed to get proc done"); + let res: Response = serde_json::from_str(&out).unwrap(); + match &res.payload[0] { + ResponseData::ProcDone { success, .. } => { + assert!(success, "Process failed unexpectedly"); + } + x => panic!("Unexpected response: {:?}", x), + }; + + // Verify that we received nothing on stderr channel + assert!( + stderr.try_recv().is_err(), + "Unexpectedly got result on stderr channel" + ); +} + +#[rstest] +fn should_support_json_to_forward_stdin_to_remote_process(ctx: &'_ DistantServerCtx) { + let req = Request { + id: rand::random(), + tenant: random_tenant(), + payload: vec![RequestData::ProcRun { + cmd: SCRIPT_RUNNER.to_string(), + args: vec![ECHO_STDIN_TO_STDOUT_SH.to_str().unwrap().to_string()], + }], + }; + + // distant action --format json --interactive + let mut child = distant_subcommand(ctx, "action") + .args(&["--format", "json"]) + .arg("--interactive") + .args(&["--log-file", "/tmp/test.log", "-vvv"]) + .spawn() + .unwrap(); + + let mut stdin = child.stdin.take().unwrap(); + let stdout = spawn_line_reader(child.stdout.take().unwrap()); + let stderr = spawn_line_reader(child.stderr.take().unwrap()); + + // Send our request as json + let req_string = format!("{}\n", serde_json::to_string(&req).unwrap()); + stdin.write_all(req_string.as_bytes()).unwrap(); + stdin.flush().unwrap(); + + // Get the indicator of a process started (first line returned can take ~7 seconds due to the + // handshake cost) + let out = + friendly_recv_line(&stdout, Duration::from_secs(30)).expect("Failed to get proc start"); + let res: Response = serde_json::from_str(&out).unwrap(); + let id = match &res.payload[0] { + ResponseData::ProcStart { id } => *id, + x => panic!("Unexpected response: {:?}", x), + }; + + // Send stdin to remote process + let req = Request { + id: rand::random(), + tenant: random_tenant(), + payload: vec![RequestData::ProcStdin { + id, + data: String::from("hello world\n"), + }], + }; + let req_string = format!("{}\n", serde_json::to_string(&req).unwrap()); + stdin.write_all(req_string.as_bytes()).unwrap(); + stdin.flush().unwrap(); + + // Should receive ok message + let out = friendly_recv_line(&stdout, Duration::from_secs(1)) + .expect("Failed to get ok response from proc stdin"); + let res: Response = serde_json::from_str(&out).unwrap(); + match &res.payload[0] { + ResponseData::Ok => {} + x => panic!("Unexpected response: {:?}", x), + }; + + // Get stdout from process and verify it + let out = + friendly_recv_line(&stdout, Duration::from_secs(1)).expect("Failed to get proc stdout"); + let res: Response = serde_json::from_str(&out).unwrap(); + match &res.payload[0] { + ResponseData::ProcStdout { data, .. } => assert_eq!(data, "hello world\n"), + x => panic!("Unexpected response: {:?}", x), + }; + + // Kill the remote process since it only terminates when stdin closes, but we + // want to verify that we get a proc done is some manner, which won't happen + // if stdin closes as our interactive process will also close + let req = Request { + id: rand::random(), + tenant: random_tenant(), + payload: vec![RequestData::ProcKill { id }], + }; + let req_string = format!("{}\n", serde_json::to_string(&req).unwrap()); + stdin.write_all(req_string.as_bytes()).unwrap(); + stdin.flush().unwrap(); + + // Should receive ok message + let out = friendly_recv_line(&stdout, Duration::from_secs(1)) + .expect("Failed to get ok response from proc stdin"); + let res: Response = serde_json::from_str(&out).unwrap(); + match &res.payload[0] { + ResponseData::Ok => {} + x => panic!("Unexpected response: {:?}", x), + }; + + // Get the indicator of a process completion + let out = friendly_recv_line(&stdout, Duration::from_secs(1)).expect("Failed to get proc done"); + let res: Response = serde_json::from_str(&out).unwrap(); + match &res.payload[0] { + ResponseData::ProcDone { success, .. } => { + assert!(!success, "Process succeeded unexpectedly"); + } + x => panic!("Unexpected response: {:?}", x), + }; + + // Verify that we received nothing on stderr channel + assert!( + stderr.try_recv().is_err(), + "Unexpectedly got result on stderr channel" + ); +} + #[rstest] fn should_support_json_output_for_error(mut action_cmd: Command) { let req = Request { diff --git a/tests/cli/utils.rs b/tests/cli/utils.rs index e950e84..f00c253 100644 --- a/tests/cli/utils.rs +++ b/tests/cli/utils.rs @@ -1,5 +1,12 @@ +use crate::cli::fixtures::DistantServerCtx; use predicates::prelude::*; -use std::path::PathBuf; +use std::{ + env, io, + path::PathBuf, + process::{Command, Stdio}, + sync::mpsc, + time::{Duration, Instant}, +}; lazy_static::lazy_static! { /// Predicate that checks for a single line that is a failure @@ -38,3 +45,106 @@ pub fn init_logging(path: impl Into) -> flexi_logger::LoggerHandle { logger.start().expect("Failed to initialize logger") } + +pub fn friendly_recv_line( + receiver: &mpsc::Receiver, + duration: Duration, +) -> io::Result { + let start = Instant::now(); + loop { + if let Ok(line) = receiver.try_recv() { + break Ok(line); + } + + if start.elapsed() > duration { + return Err(io::Error::new( + io::ErrorKind::TimedOut, + format!("Failed to receive line after {}s", duration.as_secs_f32()), + )); + } + + std::thread::yield_now(); + } +} + +pub fn spawn_line_reader(mut reader: T) -> mpsc::Receiver +where + T: std::io::Read + Send + 'static, +{ + let id = rand::random::(); + let (tx, rx) = mpsc::channel(); + std::thread::spawn(move || { + let mut buf = String::new(); + let mut tmp = [0; 1024]; + while let Ok(n) = reader.read(&mut tmp) { + if n == 0 { + break; + } + + let data = String::from_utf8_lossy(&tmp[..n]); + buf.push_str(data.as_ref()); + + // Send all complete lines + match buf.rfind('\n') { + Some(idx) => { + let remaining = buf.split_off(idx + 1); + for line in buf.lines() { + tx.send(line.to_string()).unwrap(); + } + buf = remaining; + } + None => {} + } + } + + // If something is remaining at end, also send it + if !buf.is_empty() { + tx.send(buf).unwrap(); + } + }); + + rx +} + +/// Produces a new command for distant using the given subcommand +pub fn distant_subcommand(ctx: &DistantServerCtx, subcommand: &str) -> Command { + let mut cmd = Command::new(cargo_bin(env!("CARGO_PKG_NAME"))); + cmd.arg(subcommand) + .args(&["--session", "environment"]) + .env("DISTANT_HOST", ctx.addr.ip().to_string()) + .env("DISTANT_PORT", ctx.addr.port().to_string()) + .env("DISTANT_AUTH_KEY", ctx.auth_key.as_str()) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()); + cmd +} + +/// Look up the path to a cargo-built binary within an integration test +/// +/// Taken from https://github.com/assert-rs/assert_cmd/blob/036ef47b8ad170dcaf4eaf4412c0b48fd5b6ef6e/src/cargo.rs#L199 +fn cargo_bin>(name: S) -> PathBuf { + cargo_bin_str(name.as_ref()) +} + +fn cargo_bin_str(name: &str) -> PathBuf { + let env_var = format!("CARGO_BIN_EXE_{}", name); + std::env::var_os(&env_var) + .map(|p| p.into()) + .unwrap_or_else(|| target_dir().join(format!("{}{}", name, env::consts::EXE_SUFFIX))) +} + +// Adapted from +// https://github.com/rust-lang/cargo/blob/485670b3983b52289a2f353d589c57fae2f60f82/tests/testsuite/support/mod.rs#L507 +fn target_dir() -> PathBuf { + env::current_exe() + .ok() + .map(|mut path| { + path.pop(); + if path.ends_with("deps") { + path.pop(); + } + path + }) + .unwrap() +}