Add support for piping stdin to remote proccess and reduce size of packets to just include the total bytes as a header

pull/38/head
Chip Senkbeil 3 years ago
parent f59ae7f6ed
commit 3a2749fd7f
No known key found for this signature in database
GPG Key ID: 35EF1F8EC72A4131

@ -4,16 +4,13 @@ use std::convert::TryInto;
use tokio::io;
use tokio_util::codec::{Decoder, Encoder};
/// Represents a marker to indicate the beginning of the next message
static MSG_MARKER: &'static [u8] = b";msg;";
/// Total size in bytes that is used for storing length
static LEN_SIZE: usize = 8;
#[inline]
fn frame_size(msg_size: usize) -> usize {
// MARKER + u64 (8 bytes) + msg size
MSG_MARKER.len() + LEN_SIZE + msg_size
// u64 (8 bytes) + msg size
LEN_SIZE + msg_size
}
/// Possible errors that can occur during encoding and decoding
@ -34,7 +31,6 @@ impl<'a> Encoder<&'a [u8]> for DistantCodec {
fn encode(&mut self, item: &'a [u8], dst: &mut BytesMut) -> Result<(), Self::Error> {
// Add our full frame to the bytes
dst.reserve(frame_size(item.len()));
dst.put(MSG_MARKER);
dst.put_u64(item.len() as u64);
dst.put(item);
@ -47,31 +43,18 @@ impl Decoder for DistantCodec {
type Error = DistantCodecError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
// First, check if we have more data than just our markers, if not we say that it's okay
// but that we're waiting
if src.len() <= (MSG_MARKER.len() + LEN_SIZE) {
// First, check if we have more data than just our frame's message length
if src.len() <= LEN_SIZE {
return Ok(None);
}
// Second, verify that our first N bytes match our start marker
let marker_start = &src[..MSG_MARKER.len()];
if marker_start != MSG_MARKER {
return Err(DistantCodecError::CorruptMarker(Bytes::copy_from_slice(
marker_start,
)));
}
// Third, retrieve total size of our msg
let msg_len = u64::from_be_bytes(
src[MSG_MARKER.len()..MSG_MARKER.len() + LEN_SIZE]
.try_into()
.unwrap(),
);
// Second, retrieve total size of our frame's message
let msg_len = u64::from_be_bytes(src[..LEN_SIZE].try_into().unwrap());
// Fourth, return our msg if it's available, stripping it of the start and end markers
// Third, return our msg if it's available, stripping it of the length data
let frame_len = frame_size(msg_len as usize);
if src.len() >= frame_len {
let data = src[MSG_MARKER.len() + LEN_SIZE..frame_len].to_vec();
let data = src[LEN_SIZE..frame_len].to_vec();
// Advance so frame is no longer kept around
src.advance(frame_len);

@ -117,7 +117,7 @@ pub struct TransportWriteHalf {
}
impl TransportWriteHalf {
/// Sends some data across the wire
/// Sends some data across the wire, waiting for it to completely send
pub async fn send<T: Serialize>(&mut self, data: T) -> Result<(), TransportError> {
// Serialize, encrypt, and then (TODO) sign
// NOTE: Cannot used packed implementation for now due to issues with deserialization

@ -5,7 +5,8 @@ use crate::{
utils::{Session, SessionError},
};
use derive_more::{Display, Error, From};
use tokio::io;
use log::*;
use tokio::{io, sync::mpsc};
use tokio_stream::StreamExt;
#[derive(Debug, Display, Error, From)]
@ -36,12 +37,40 @@ async fn run_async(cmd: SendSubcommand, _opt: CommonOpt) -> Result<(), Error> {
};
let res = client.send(req).await?;
// Store the spawned process id for using in sending stdin (if we spawned a proc)
let proc_id = match &res.payload {
ResponsePayload::ProcStart { id } => *id,
_ => 0,
};
print_response(cmd.format, res)?;
// If we are executing a process and not detaching, we want to continue receiving
// responses sent to us
if is_proc_req && not_detach {
let mut stream = client.to_response_stream();
// We also want to spawn a task to handle sending stdin to the remote process
let mut rx = spawn_stdin_reader();
tokio::spawn(async move {
while let Some(line) = rx.recv().await {
trace!("Client sending stdin: {:?}", line);
let req = Request::from(RequestPayload::ProcStdin {
id: proc_id,
data: line.into_bytes(),
});
let result = client.send(req).await;
if let Err(x) = result {
error!(
"Failed to send stdin to remote process ({}): {}",
proc_id, x
);
}
}
});
while let Some(res) = stream.next().await {
let res = res.map_err(|_| {
io::Error::new(
@ -90,6 +119,32 @@ fn print_response(fmt: ResponseFormat, res: Response) -> io::Result<()> {
Ok(())
}
fn spawn_stdin_reader() -> mpsc::Receiver<String> {
let (tx, rx) = mpsc::channel(1);
// NOTE: Using blocking I/O per tokio's advice to read from stdin line-by-line and then
// pass the results to a separate async handler to forward to the remote process
std::thread::spawn(move || {
let stdin = std::io::stdin();
loop {
let mut line = String::new();
if stdin.read_line(&mut line).is_ok() {
if let Err(x) = tx.blocking_send(line) {
error!(
"Failed to pass along stdin to be sent to remote process: {}",
x
);
}
} else {
break;
}
}
});
rx
}
fn format_response(fmt: ResponseFormat, res: Response) -> io::Result<String> {
Ok(match fmt {
ResponseFormat::Human => format_human(res),

Loading…
Cancel
Save