Add session ext (#53)

* Add `SessionExt` trait for friendlier methods
* Create `Mailbox` and internal `PostOffice` to manage responses to requests (see the usage sketch after this list)
* Refactor `Session` to use a new `SessionChannel` underneath
* Refactor `Response` to always include an `origin_id` field instead of it being optional
* Update `ProcStdout`, `ProcStderr`, and `ProcDone` to include the origin id
* Replace `verbose` option with `log-level`
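
A rough, hedged sketch of the new request/mailbox flow described above, assuming an already-connected `Session` (the `distant::client` and `distant::data` import paths are assumptions based on this diff's module layout, and error handling is simplified):

use distant::{
    client::Session,
    data::{Request, RequestData},
};

async fn run_echo(session: &mut Session) -> Result<(), Box<dyn std::error::Error>> {
    // `mail` returns a Mailbox that keeps receiving every response whose
    // origin_id matches this request's id (ProcStdout, ProcStderr, ProcDone, ...)
    let mut mailbox = session
        .mail(Request::new(
            "my-tenant",
            vec![RequestData::ProcRun {
                cmd: "echo".to_string(),
                args: vec!["hello".to_string()],
            }],
        ))
        .await?;

    while let Some(res) = mailbox.next().await {
        println!(
            "response to request {} with {} payload item(s)",
            res.origin_id,
            res.payload.len()
        );
    }
    Ok(())
}
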
pull/55/head v0.14.0
Chip Senkbeil 3 years ago committed by GitHub
parent c45aea8fe7
commit c4d1011b14

@ -12,10 +12,10 @@ servers that operate on remote machines and clients that talk to them.
- Asynchronous in nature, powered by [`tokio`](https://tokio.rs/)
- Data is serialized to send across the wire via [`CBOR`](https://cbor.io/)
- Encryption & authentication are handled via [`orion`](https://crates.io/crates/orion)
- [XChaCha20Poly1305](https://cryptopp.com/wiki/XChaCha20Poly1305) for an authenticated encryption scheme
- [BLAKE2b-256](https://www.blake2.net/) in keyed mode for a second authentication
- [Elliptic Curve Diffie-Hellman](https://en.wikipedia.org/wiki/Elliptic-curve_Diffie%E2%80%93Hellman) (ECDH) for key exchange
- Encryption & authentication are handled via
[XChaCha20Poly1305](https://tools.ietf.org/html/rfc8439), an authenticated
encryption scheme, implemented by
[RustCrypto/ChaCha20Poly1305](https://github.com/RustCrypto/AEADs/tree/master/chacha20poly1305)
## Installation

@ -1,8 +1,5 @@
use super::{RemoteProcess, RemoteProcessError, RemoteStderr, RemoteStdin, RemoteStdout};
use crate::{
client::Session,
net::{Codec, DataStream},
};
use crate::client::Session;
use futures::stream::{Stream, StreamExt};
use std::{
fmt::Write,
@ -26,16 +23,12 @@ pub struct RemoteLspProcess {
impl RemoteLspProcess {
/// Spawns the specified process on the remote machine using the given session, treating
/// the process like an LSP server
pub async fn spawn<T, U>(
tenant: String,
session: Session<T, U>,
cmd: String,
pub async fn spawn(
tenant: impl Into<String>,
session: &mut Session,
cmd: impl Into<String>,
args: Vec<String>,
) -> Result<Self, RemoteProcessError>
where
T: DataStream + 'static,
U: Codec + Send + 'static,
{
) -> Result<Self, RemoteProcessError> {
let mut inner = RemoteProcess::spawn(tenant, session, cmd, args).await?;
let stdin = inner.stdin.take().map(RemoteLspStdin::new);
let stdout = inner.stdout.take().map(RemoteLspStdout::new);
@ -266,13 +259,16 @@ mod tests {
// Configures an lsp process with a means to send & receive data from outside
async fn spawn_lsp_process() -> (Transport<InmemoryStream, PlainCodec>, RemoteLspProcess) {
let (mut t1, t2) = Transport::make_pair();
let session = Session::initialize(t2).unwrap();
let spawn_task = tokio::spawn(RemoteLspProcess::spawn(
String::from("test-tenant"),
session,
String::from("cmd"),
vec![String::from("arg")],
));
let mut session = Session::initialize(t2).unwrap();
let spawn_task = tokio::spawn(async move {
RemoteLspProcess::spawn(
String::from("test-tenant"),
&mut session,
String::from("cmd"),
vec![String::from("arg")],
)
.await
});
// Wait until we get the request from the session
let req = t1.receive::<Request>().await.unwrap().unwrap();
@ -280,7 +276,7 @@ mod tests {
// Send back a response through the session
t1.send(Response::new(
"test-tenant",
Some(req.id),
req.id,
vec![ResponseData::ProcStart { id: rand::random() }],
))
.await
@ -524,7 +520,7 @@ mod tests {
transport
.send(Response::new(
"test-tenant",
None,
proc.origin_id,
vec![ResponseData::ProcStdout {
id: proc.id(),
data: make_lsp_msg(serde_json::json!({
@ -561,7 +557,7 @@ mod tests {
transport
.send(Response::new(
"test-tenant",
None,
proc.origin_id,
vec![ResponseData::ProcStdout {
id: proc.id(),
data: msg_a.to_string(),
@ -580,7 +576,7 @@ mod tests {
transport
.send(Response::new(
"test-tenant",
None,
proc.origin_id,
vec![ResponseData::ProcStdout {
id: proc.id(),
data: msg_b.to_string(),
@ -615,7 +611,7 @@ mod tests {
transport
.send(Response::new(
"test-tenant",
None,
proc.origin_id,
vec![ResponseData::ProcStdout {
id: proc.id(),
data: format!("{}{}", msg, extra),
@ -659,7 +655,7 @@ mod tests {
transport
.send(Response::new(
"test-tenant",
None,
proc.origin_id,
vec![ResponseData::ProcStdout {
id: proc.id(),
data: format!("{}{}", msg_1, msg_2),
@ -694,7 +690,7 @@ mod tests {
transport
.send(Response::new(
"test-tenant",
None,
proc.origin_id,
vec![ResponseData::ProcStdout {
id: proc.id(),
data: make_lsp_msg(serde_json::json!({
@ -725,7 +721,7 @@ mod tests {
transport
.send(Response::new(
"test-tenant",
None,
proc.origin_id,
vec![ResponseData::ProcStderr {
id: proc.id(),
data: make_lsp_msg(serde_json::json!({
@ -762,7 +758,7 @@ mod tests {
transport
.send(Response::new(
"test-tenant",
None,
proc.origin_id,
vec![ResponseData::ProcStderr {
id: proc.id(),
data: msg_a.to_string(),
@ -781,7 +777,7 @@ mod tests {
transport
.send(Response::new(
"test-tenant",
None,
proc.origin_id,
vec![ResponseData::ProcStderr {
id: proc.id(),
data: msg_b.to_string(),
@ -816,7 +812,7 @@ mod tests {
transport
.send(Response::new(
"test-tenant",
None,
proc.origin_id,
vec![ResponseData::ProcStderr {
id: proc.id(),
data: format!("{}{}", msg, extra),
@ -860,7 +856,7 @@ mod tests {
transport
.send(Response::new(
"test-tenant",
None,
proc.origin_id,
vec![ResponseData::ProcStderr {
id: proc.id(),
data: format!("{}{}", msg_1, msg_2),
@ -895,7 +891,7 @@ mod tests {
transport
.send(Response::new(
"test-tenant",
None,
proc.origin_id,
vec![ResponseData::ProcStderr {
id: proc.id(),
data: make_lsp_msg(serde_json::json!({

@ -1,8 +1,8 @@
use crate::{
client::Session,
constants::CLIENT_BROADCAST_CHANNEL_CAPACITY,
data::{Request, RequestData, Response, ResponseData},
net::{Codec, DataStream, TransportError},
client::{Mailbox, Session, SessionChannel},
constants::CLIENT_MAILBOX_CAPACITY,
data::{Request, RequestData, ResponseData},
net::TransportError,
};
use derive_more::{Display, Error, From};
use log::*;
@ -14,9 +14,6 @@ use tokio::{
#[derive(Debug, Display, Error, From)]
pub enum RemoteProcessError {
/// When the process receives an unexpected response
BadResponse,
/// When attempting to relay stdout/stderr over channels, but the channels fail
ChannelDead,
@ -37,6 +34,9 @@ pub struct RemoteProcess {
/// Id of the process
id: usize,
/// Id used to map back to mailbox
pub(crate) origin_id: usize,
/// Task that forwards stdin to the remote process by bundling it as stdin requests
req_task: JoinHandle<Result<(), RemoteProcessError>>,
@ -59,39 +59,55 @@ pub struct RemoteProcess {
impl RemoteProcess {
/// Spawns the specified process on the remote machine using the given session
pub async fn spawn<T, U>(
tenant: String,
mut session: Session<T, U>,
cmd: String,
pub async fn spawn(
tenant: impl Into<String>,
session: &mut Session,
cmd: impl Into<String>,
args: Vec<String>,
) -> Result<Self, RemoteProcessError>
where
T: DataStream + 'static,
U: Codec + Send + 'static,
{
// Submit our run request and wait for a response
let res = session
.send(Request::new(
) -> Result<Self, RemoteProcessError> {
let tenant = tenant.into();
let cmd = cmd.into();
// Submit our run request and get back a mailbox for responses
let mut mailbox = session
.mail(Request::new(
tenant.as_str(),
vec![RequestData::ProcRun { cmd, args }],
))
.await?;
// We expect a singular response back
if res.payload.len() != 1 {
return Err(RemoteProcessError::BadResponse);
}
// Response should be proc starting
let id = match res.payload.into_iter().next().unwrap() {
ResponseData::ProcStart { id } => id,
_ => return Err(RemoteProcessError::BadResponse),
// Wait until we get the first response, and get id from proc started
let (id, origin_id) = match mailbox.next().await {
Some(res) if res.payload.len() != 1 => {
return Err(RemoteProcessError::TransportError(TransportError::IoError(
io::Error::new(io::ErrorKind::InvalidData, "Got wrong payload size"),
)));
}
Some(res) => {
let origin_id = res.origin_id;
match res.payload.into_iter().next().unwrap() {
ResponseData::ProcStart { id } => (id, origin_id),
x => {
return Err(RemoteProcessError::TransportError(TransportError::IoError(
io::Error::new(
io::ErrorKind::InvalidData,
format!("Got response type of {}", x.as_ref()),
),
)))
}
}
}
None => {
return Err(RemoteProcessError::TransportError(TransportError::IoError(
io::Error::from(io::ErrorKind::ConnectionAborted),
)))
}
};
// Create channels for our stdin/stdout/stderr
let (stdin_tx, stdin_rx) = mpsc::channel(CLIENT_BROADCAST_CHANNEL_CAPACITY);
let (stdout_tx, stdout_rx) = mpsc::channel(CLIENT_BROADCAST_CHANNEL_CAPACITY);
let (stderr_tx, stderr_rx) = mpsc::channel(CLIENT_BROADCAST_CHANNEL_CAPACITY);
let (stdin_tx, stdin_rx) = mpsc::channel(CLIENT_MAILBOX_CAPACITY);
let (stdout_tx, stdout_rx) = mpsc::channel(CLIENT_MAILBOX_CAPACITY);
let (stderr_tx, stderr_rx) = mpsc::channel(CLIENT_MAILBOX_CAPACITY);
// Used to terminate request task, either explicitly by the process or internally
// by the response task when it terminates
@ -100,18 +116,19 @@ impl RemoteProcess {
// Now we spawn a task to handle future responses that are async
// such as ProcStdout, ProcStderr, and ProcDone
let kill_tx_2 = kill_tx.clone();
let broadcast = session.broadcast.take().unwrap();
let res_task = tokio::spawn(async move {
process_incoming_responses(id, broadcast, stdout_tx, stderr_tx, kill_tx_2).await
process_incoming_responses(id, mailbox, stdout_tx, stderr_tx, kill_tx_2).await
});
// Spawn a task that takes stdin from our channel and forwards it to the remote process
let channel = session.clone_channel();
let req_task = tokio::spawn(async move {
process_outgoing_requests(tenant, id, session, stdin_rx, kill_rx).await
process_outgoing_requests(tenant, id, channel, stdin_rx, kill_rx).await
});
Ok(Self {
id,
origin_id,
req_task,
res_task,
stdin: Some(RemoteStdin(stdin_tx)),
@ -196,22 +213,18 @@ impl RemoteStderr {
/// Helper function that loops, processing outgoing stdin requests to a remote process as well as
/// supporting a kill request to terminate the remote process
async fn process_outgoing_requests<T, U>(
async fn process_outgoing_requests(
tenant: String,
id: usize,
mut session: Session<T, U>,
mut channel: SessionChannel,
mut stdin_rx: mpsc::Receiver<String>,
mut kill_rx: mpsc::Receiver<()>,
) -> Result<(), RemoteProcessError>
where
T: DataStream,
U: Codec,
{
) -> Result<(), RemoteProcessError> {
let result = loop {
tokio::select! {
data = stdin_rx.recv() => {
match data {
Some(data) => session.fire(
Some(data) => channel.fire(
Request::new(
tenant.as_str(),
vec![RequestData::ProcStdin { id, data }]
@ -222,12 +235,10 @@ where
}
msg = kill_rx.recv() => {
if msg.is_some() {
session
.fire(Request::new(
tenant.as_str(),
vec![RequestData::ProcKill { id }],
))
.await?;
channel.fire(Request::new(
tenant.as_str(),
vec![RequestData::ProcKill { id }],
)).await?;
break Ok(());
} else {
break Err(RemoteProcessError::ChannelDead);
@ -243,12 +254,12 @@ where
/// Helper function that loops, processing incoming stdout & stderr requests from a remote process
async fn process_incoming_responses(
proc_id: usize,
mut broadcast: mpsc::Receiver<Response>,
mut mailbox: Mailbox,
stdout_tx: mpsc::Sender<String>,
stderr_tx: mpsc::Sender<String>,
kill_tx: mpsc::Sender<()>,
) -> Result<(bool, Option<i32>), RemoteProcessError> {
while let Some(res) = broadcast.recv().await {
while let Some(res) = mailbox.next().await {
// Check if any of the payload data is the termination
let exit_status = res.payload.iter().find_map(|data| match data {
ResponseData::ProcDone { id, success, code } if *id == proc_id => {
@ -291,61 +302,68 @@ async fn process_incoming_responses(
mod tests {
use super::*;
use crate::{
data::{Error, ErrorKind},
data::{Error, ErrorKind, Response},
net::{InmemoryStream, PlainCodec, Transport},
};
fn make_session() -> (
Transport<InmemoryStream, PlainCodec>,
Session<InmemoryStream, PlainCodec>,
) {
fn make_session() -> (Transport<InmemoryStream, PlainCodec>, Session) {
let (t1, t2) = Transport::make_pair();
(t1, Session::initialize(t2).unwrap())
}
#[tokio::test]
async fn spawn_should_return_bad_response_if_payload_size_unexpected() {
let (mut transport, session) = make_session();
async fn spawn_should_return_invalid_data_if_payload_size_unexpected() {
let (mut transport, mut session) = make_session();
// Create a task for process spawning as we need to handle the request and a response
// in a separate async block
let spawn_task = tokio::spawn(RemoteProcess::spawn(
String::from("test-tenant"),
session,
String::from("cmd"),
vec![String::from("arg")],
));
let spawn_task = tokio::spawn(async move {
RemoteProcess::spawn(
String::from("test-tenant"),
&mut session,
String::from("cmd"),
vec![String::from("arg")],
)
.await
});
// Wait until we get the request from the session
let req = transport.receive::<Request>().await.unwrap().unwrap();
// Send back a response through the session
transport
.send(Response::new("test-tenant", Some(req.id), Vec::new()))
.send(Response::new("test-tenant", req.id, Vec::new()))
.await
.unwrap();
// Get the spawn result and verify
let result = spawn_task.await.unwrap();
assert!(
matches!(result, Err(RemoteProcessError::BadResponse)),
matches!(
&result,
Err(RemoteProcessError::TransportError(TransportError::IoError(x)))
if x.kind() == io::ErrorKind::InvalidData
),
"Unexpected result: {:?}",
result
);
}
#[tokio::test]
async fn spawn_should_return_bad_response_if_did_not_get_a_indicator_that_process_started() {
let (mut transport, session) = make_session();
async fn spawn_should_return_invalid_data_if_did_not_get_a_indicator_that_process_started() {
let (mut transport, mut session) = make_session();
// Create a task for process spawning as we need to handle the request and a response
// in a separate async block
let spawn_task = tokio::spawn(RemoteProcess::spawn(
String::from("test-tenant"),
session,
String::from("cmd"),
vec![String::from("arg")],
));
let spawn_task = tokio::spawn(async move {
RemoteProcess::spawn(
String::from("test-tenant"),
&mut session,
String::from("cmd"),
vec![String::from("arg")],
)
.await
});
// Wait until we get the request from the session
let req = transport.receive::<Request>().await.unwrap().unwrap();
@ -354,7 +372,7 @@ mod tests {
transport
.send(Response::new(
"test-tenant",
Some(req.id),
req.id,
vec![ResponseData::Error(Error {
kind: ErrorKind::Other,
description: String::from("some error"),
@ -366,7 +384,11 @@ mod tests {
// Get the spawn result and verify
let result = spawn_task.await.unwrap();
assert!(
matches!(result, Err(RemoteProcessError::BadResponse)),
matches!(
&result,
Err(RemoteProcessError::TransportError(TransportError::IoError(x)))
if x.kind() == io::ErrorKind::InvalidData
),
"Unexpected result: {:?}",
result
);
@ -374,16 +396,19 @@ mod tests {
#[tokio::test]
async fn kill_should_return_error_if_internal_tasks_already_completed() {
let (mut transport, session) = make_session();
let (mut transport, mut session) = make_session();
// Create a task for process spawning as we need to handle the request and a response
// in a separate async block
let spawn_task = tokio::spawn(RemoteProcess::spawn(
String::from("test-tenant"),
session,
String::from("cmd"),
vec![String::from("arg")],
));
let spawn_task = tokio::spawn(async move {
RemoteProcess::spawn(
String::from("test-tenant"),
&mut session,
String::from("cmd"),
vec![String::from("arg")],
)
.await
});
// Wait until we get the request from the session
let req = transport.receive::<Request>().await.unwrap().unwrap();
@ -393,7 +418,7 @@ mod tests {
transport
.send(Response::new(
"test-tenant",
Some(req.id),
req.id,
vec![ResponseData::ProcStart { id }],
))
.await
@ -416,16 +441,19 @@ mod tests {
#[tokio::test]
async fn kill_should_send_proc_kill_request_and_then_cause_stdin_forwarding_to_close() {
let (mut transport, session) = make_session();
let (mut transport, mut session) = make_session();
// Create a task for process spawning as we need to handle the request and a response
// in a separate async block
let spawn_task = tokio::spawn(RemoteProcess::spawn(
String::from("test-tenant"),
session,
String::from("cmd"),
vec![String::from("arg")],
));
let spawn_task = tokio::spawn(async move {
RemoteProcess::spawn(
String::from("test-tenant"),
&mut session,
String::from("cmd"),
vec![String::from("arg")],
)
.await
});
// Wait until we get the request from the session
let req = transport.receive::<Request>().await.unwrap().unwrap();
@ -435,7 +463,7 @@ mod tests {
transport
.send(Response::new(
"test-tenant",
Some(req.id),
req.id,
vec![ResponseData::ProcStart { id }],
))
.await
@ -469,16 +497,19 @@ mod tests {
#[tokio::test]
async fn stdin_should_be_forwarded_from_receiver_field() {
let (mut transport, session) = make_session();
let (mut transport, mut session) = make_session();
// Create a task for process spawning as we need to handle the request and a response
// in a separate async block
let spawn_task = tokio::spawn(RemoteProcess::spawn(
String::from("test-tenant"),
session,
String::from("cmd"),
vec![String::from("arg")],
));
let spawn_task = tokio::spawn(async move {
RemoteProcess::spawn(
String::from("test-tenant"),
&mut session,
String::from("cmd"),
vec![String::from("arg")],
)
.await
});
// Wait until we get the request from the session
let req = transport.receive::<Request>().await.unwrap().unwrap();
@ -488,7 +519,7 @@ mod tests {
transport
.send(Response::new(
"test-tenant",
Some(req.id),
req.id,
vec![ResponseData::ProcStart { id }],
))
.await
@ -521,16 +552,19 @@ mod tests {
#[tokio::test]
async fn stdout_should_be_forwarded_to_receiver_field() {
let (mut transport, session) = make_session();
let (mut transport, mut session) = make_session();
// Create a task for process spawning as we need to handle the request and a response
// in a separate async block
let spawn_task = tokio::spawn(RemoteProcess::spawn(
String::from("test-tenant"),
session,
String::from("cmd"),
vec![String::from("arg")],
));
let spawn_task = tokio::spawn(async move {
RemoteProcess::spawn(
String::from("test-tenant"),
&mut session,
String::from("cmd"),
vec![String::from("arg")],
)
.await
});
// Wait until we get the request from the session
let req = transport.receive::<Request>().await.unwrap().unwrap();
@ -540,7 +574,7 @@ mod tests {
transport
.send(Response::new(
"test-tenant",
Some(req.id),
req.id,
vec![ResponseData::ProcStart { id }],
))
.await
@ -552,7 +586,7 @@ mod tests {
transport
.send(Response::new(
"test-tenant",
None,
req.id,
vec![ResponseData::ProcStdout {
id,
data: String::from("some out"),
@ -567,16 +601,19 @@ mod tests {
#[tokio::test]
async fn stderr_should_be_forwarded_to_receiver_field() {
let (mut transport, session) = make_session();
let (mut transport, mut session) = make_session();
// Create a task for process spawning as we need to handle the request and a response
// in a separate async block
let spawn_task = tokio::spawn(RemoteProcess::spawn(
String::from("test-tenant"),
session,
String::from("cmd"),
vec![String::from("arg")],
));
let spawn_task = tokio::spawn(async move {
RemoteProcess::spawn(
String::from("test-tenant"),
&mut session,
String::from("cmd"),
vec![String::from("arg")],
)
.await
});
// Wait until we get the request from the session
let req = transport.receive::<Request>().await.unwrap().unwrap();
@ -586,7 +623,7 @@ mod tests {
transport
.send(Response::new(
"test-tenant",
Some(req.id),
req.id,
vec![ResponseData::ProcStart { id }],
))
.await
@ -598,7 +635,7 @@ mod tests {
transport
.send(Response::new(
"test-tenant",
None,
req.id,
vec![ResponseData::ProcStderr {
id,
data: String::from("some err"),
@ -613,16 +650,19 @@ mod tests {
#[tokio::test]
async fn wait_should_return_error_if_internal_tasks_fail() {
let (mut transport, session) = make_session();
let (mut transport, mut session) = make_session();
// Create a task for process spawning as we need to handle the request and a response
// in a separate async block
let spawn_task = tokio::spawn(RemoteProcess::spawn(
String::from("test-tenant"),
session,
String::from("cmd"),
vec![String::from("arg")],
));
let spawn_task = tokio::spawn(async move {
RemoteProcess::spawn(
String::from("test-tenant"),
&mut session,
String::from("cmd"),
vec![String::from("arg")],
)
.await
});
// Wait until we get the request from the session
let req = transport.receive::<Request>().await.unwrap().unwrap();
@ -632,7 +672,7 @@ mod tests {
transport
.send(Response::new(
"test-tenant",
Some(req.id),
req.id,
vec![ResponseData::ProcStart { id }],
))
.await
@ -652,16 +692,19 @@ mod tests {
#[tokio::test]
async fn wait_should_return_error_if_connection_terminates_before_receiving_done_response() {
let (mut transport, session) = make_session();
let (mut transport, mut session) = make_session();
// Create a task for process spawning as we need to handle the request and a response
// in a separate async block
let spawn_task = tokio::spawn(RemoteProcess::spawn(
String::from("test-tenant"),
session,
String::from("cmd"),
vec![String::from("arg")],
));
let spawn_task = tokio::spawn(async move {
RemoteProcess::spawn(
String::from("test-tenant"),
&mut session,
String::from("cmd"),
vec![String::from("arg")],
)
.await
});
// Wait until we get the request from the session
let req = transport.receive::<Request>().await.unwrap().unwrap();
@ -671,7 +714,7 @@ mod tests {
transport
.send(Response::new(
"test-tenant",
Some(req.id),
req.id,
vec![ResponseData::ProcStart { id }],
))
.await
@ -679,6 +722,10 @@ mod tests {
// Receive the process and then terminate session connection
let proc = spawn_task.await.unwrap().unwrap();
// Ensure that the spawned task gets a chance to wait on stdout/stderr
tokio::task::yield_now().await;
drop(transport);
// Ensure that the other tasks are cancelled before continuing
@ -694,16 +741,19 @@ mod tests {
#[tokio::test]
async fn receiving_done_response_should_result_in_wait_returning_exit_information() {
let (mut transport, session) = make_session();
let (mut transport, mut session) = make_session();
// Create a task for process spawning as we need to handle the request and a response
// in a separate async block
let spawn_task = tokio::spawn(RemoteProcess::spawn(
String::from("test-tenant"),
session,
String::from("cmd"),
vec![String::from("arg")],
));
let spawn_task = tokio::spawn(async move {
RemoteProcess::spawn(
String::from("test-tenant"),
&mut session,
String::from("cmd"),
vec![String::from("arg")],
)
.await
});
// Wait until we get the request from the session
let req = transport.receive::<Request>().await.unwrap().unwrap();
@ -713,7 +763,7 @@ mod tests {
transport
.send(Response::new(
"test-tenant",
Some(req.id),
req.id,
vec![ResponseData::ProcStart { id }],
))
.await
@ -727,7 +777,7 @@ mod tests {
transport
.send(Response::new(
"test-tenant",
None,
req.id,
vec![ResponseData::ProcDone {
id,
success: false,
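
For reference, a minimal hedged sketch of how the refactored `RemoteProcess::spawn` above could be called now that it borrows the session instead of consuming it (the `distant::client` import path is an assumption):

use distant::client::{RemoteProcess, RemoteProcessError, Session};

async fn start_cat(session: &mut Session) -> Result<RemoteProcess, RemoteProcessError> {
    // spawn now takes &mut Session plus impl Into<String> for tenant and cmd,
    // so string literals work directly and the session remains usable afterwards
    let proc = RemoteProcess::spawn("my-tenant", session, "cat", Vec::new()).await?;
    Ok(proc)
}
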

@ -0,0 +1,435 @@
use crate::{
client::{RemoteProcess, RemoteProcessError, Session},
data::{DirEntry, Error as Failure, FileType, Request, RequestData, ResponseData},
net::TransportError,
};
use derive_more::{Display, Error, From};
use std::{future::Future, path::PathBuf, pin::Pin};
/// Represents an error that can occur related to convenience functions tied to a [`Session`]
#[derive(Debug, Display, Error, From)]
pub enum SessionExtError {
/// Occurs when the remote action fails
Failure(#[error(not(source))] Failure),
/// Occurs when a transport error is encountered
TransportError(TransportError),
/// Occurs when receiving a response that was not expected
MismatchedResponse,
}
pub type AsyncReturn<'a, T, E = SessionExtError> =
Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'a>>;
/// Represents metadata about some path on a remote machine
pub struct Metadata {
pub file_type: FileType,
pub len: u64,
pub readonly: bool,
pub canonicalized_path: Option<PathBuf>,
pub accessed: Option<u128>,
pub created: Option<u128>,
pub modified: Option<u128>,
}
/// Provides convenience functions on top of a [`Session`]
pub trait SessionExt {
/// Appends to a remote file using the data from a collection of bytes
fn append_file(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
data: impl Into<Vec<u8>>,
) -> AsyncReturn<'_, ()>;
/// Appends to a remote file using the data from a string
fn append_file_text(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
data: impl Into<String>,
) -> AsyncReturn<'_, ()>;
/// Copies a remote file or directory from src to dst
fn copy(
&mut self,
tenant: impl Into<String>,
src: impl Into<PathBuf>,
dst: impl Into<PathBuf>,
) -> AsyncReturn<'_, ()>;
/// Creates a remote directory, optionally creating all parent components if specified
fn create_dir(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
all: bool,
) -> AsyncReturn<'_, ()>;
/// Checks if a path exists on a remote machine
fn exists(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
) -> AsyncReturn<'_, bool>;
/// Retrieves metadata about a path on a remote machine
fn metadata(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
canonicalize: bool,
resolve_file_type: bool,
) -> AsyncReturn<'_, Metadata>;
/// Reads entries from a directory, returning a tuple of directory entries and failures
fn read_dir(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
depth: usize,
absolute: bool,
canonicalize: bool,
include_root: bool,
) -> AsyncReturn<'_, (Vec<DirEntry>, Vec<Failure>)>;
/// Reads a remote file as a collection of bytes
fn read_file(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
) -> AsyncReturn<'_, Vec<u8>>;
/// Returns a remote file as a string
fn read_file_text(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
) -> AsyncReturn<'_, String>;
/// Removes a remote file or directory, supporting removal of non-empty directories if
/// force is true
fn remove(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
force: bool,
) -> AsyncReturn<'_, ()>;
/// Renames a remote file or directory from src to dst
fn rename(
&mut self,
tenant: impl Into<String>,
src: impl Into<PathBuf>,
dst: impl Into<PathBuf>,
) -> AsyncReturn<'_, ()>;
/// Spawns a process on the remote machine
fn spawn(
&mut self,
tenant: impl Into<String>,
cmd: impl Into<String>,
args: Vec<String>,
) -> AsyncReturn<'_, RemoteProcess, RemoteProcessError>;
/// Writes a remote file with the data from a collection of bytes
fn write_file(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
data: impl Into<Vec<u8>>,
) -> AsyncReturn<'_, ()>;
/// Writes a remote file with the data from a string
fn write_file_text(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
data: impl Into<String>,
) -> AsyncReturn<'_, ()>;
}
macro_rules! make_body {
($self:expr, $tenant:expr, $data:expr, @ok) => {
make_body!($self, $tenant, $data, |data| {
if data.is_ok() {
Ok(())
} else {
Err(SessionExtError::MismatchedResponse)
}
})
};
($self:expr, $tenant:expr, $data:expr, $and_then:expr) => {{
let req = Request::new($tenant, vec![$data]);
Box::pin(async move {
$self
.send(req)
.await
.map_err(SessionExtError::from)
.and_then(|res| {
if res.payload.len() == 1 {
Ok(res.payload.into_iter().next().unwrap())
} else {
Err(SessionExtError::MismatchedResponse)
}
})
.and_then($and_then)
})
}};
}
impl SessionExt for Session {
/// Appends to a remote file using the data from a collection of bytes
fn append_file(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
data: impl Into<Vec<u8>>,
) -> AsyncReturn<'_, ()> {
make_body!(
self,
tenant,
RequestData::FileAppend { path: path.into(), data: data.into() },
@ok
)
}
/// Appends to a remote file using the data from a string
fn append_file_text(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
data: impl Into<String>,
) -> AsyncReturn<'_, ()> {
make_body!(
self,
tenant,
RequestData::FileAppendText { path: path.into(), text: data.into() },
@ok
)
}
/// Copies a remote file or directory from src to dst
fn copy(
&mut self,
tenant: impl Into<String>,
src: impl Into<PathBuf>,
dst: impl Into<PathBuf>,
) -> AsyncReturn<'_, ()> {
make_body!(
self,
tenant,
RequestData::Copy { src: src.into(), dst: dst.into() },
@ok
)
}
/// Creates a remote directory, optionally creating all parent components if specified
fn create_dir(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
all: bool,
) -> AsyncReturn<'_, ()> {
make_body!(
self,
tenant,
RequestData::DirCreate { path: path.into(), all },
@ok
)
}
/// Checks if a path exists on a remote machine
fn exists(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
) -> AsyncReturn<'_, bool> {
make_body!(
self,
tenant,
RequestData::Exists { path: path.into() },
|data| match data {
ResponseData::Exists(x) => Ok(x),
_ => Err(SessionExtError::MismatchedResponse),
}
)
}
/// Retrieves metadata about a path on a remote machine
fn metadata(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
canonicalize: bool,
resolve_file_type: bool,
) -> AsyncReturn<'_, Metadata> {
make_body!(
self,
tenant,
RequestData::Metadata {
path: path.into(),
canonicalize,
resolve_file_type
},
|data| match data {
ResponseData::Metadata {
canonicalized_path,
file_type,
len,
readonly,
accessed,
created,
modified,
} => Ok(Metadata {
canonicalized_path,
file_type,
len,
readonly,
accessed,
created,
modified,
}),
_ => Err(SessionExtError::MismatchedResponse),
}
)
}
/// Reads entries from a directory, returning a tuple of directory entries and failures
fn read_dir(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
depth: usize,
absolute: bool,
canonicalize: bool,
include_root: bool,
) -> AsyncReturn<'_, (Vec<DirEntry>, Vec<Failure>)> {
make_body!(
self,
tenant,
RequestData::DirRead {
path: path.into(),
depth,
absolute,
canonicalize,
include_root
},
|data| match data {
ResponseData::DirEntries { entries, errors } => Ok((entries, errors)),
_ => Err(SessionExtError::MismatchedResponse),
}
)
}
/// Reads a remote file as a collection of bytes
fn read_file(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
) -> AsyncReturn<'_, Vec<u8>> {
make_body!(
self,
tenant,
RequestData::FileRead { path: path.into() },
|data| match data {
ResponseData::Blob { data } => Ok(data),
_ => Err(SessionExtError::MismatchedResponse),
}
)
}
/// Returns a remote file as a string
fn read_file_text(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
) -> AsyncReturn<'_, String> {
make_body!(
self,
tenant,
RequestData::FileReadText { path: path.into() },
|data| match data {
ResponseData::Text { data } => Ok(data),
_ => Err(SessionExtError::MismatchedResponse),
}
)
}
/// Removes a remote file or directory, supporting removal of non-empty directories if
/// force is true
fn remove(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
force: bool,
) -> AsyncReturn<'_, ()> {
make_body!(
self,
tenant,
RequestData::Remove { path: path.into(), force },
@ok
)
}
/// Renames a remote file or directory from src to dst
fn rename(
&mut self,
tenant: impl Into<String>,
src: impl Into<PathBuf>,
dst: impl Into<PathBuf>,
) -> AsyncReturn<'_, ()> {
make_body!(
self,
tenant,
RequestData::Rename { src: src.into(), dst: dst.into() },
@ok
)
}
/// Spawns a process on the remote machine
fn spawn(
&mut self,
tenant: impl Into<String>,
cmd: impl Into<String>,
args: Vec<String>,
) -> AsyncReturn<'_, RemoteProcess, RemoteProcessError> {
let tenant = tenant.into();
let cmd = cmd.into();
Box::pin(async move { RemoteProcess::spawn(tenant, self, cmd, args).await })
}
/// Writes a remote file with the data from a collection of bytes
fn write_file(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
data: impl Into<Vec<u8>>,
) -> AsyncReturn<'_, ()> {
make_body!(
self,
tenant,
RequestData::FileWrite { path: path.into(), data: data.into() },
@ok
)
}
/// Writes a remote file with the data from a string
fn write_file_text(
&mut self,
tenant: impl Into<String>,
path: impl Into<PathBuf>,
data: impl Into<String>,
) -> AsyncReturn<'_, ()> {
make_body!(
self,
tenant,
RequestData::FileWriteText { path: path.into(), text: data.into() },
@ok
)
}
}
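
To show how these convenience methods read in practice, a short hedged sketch (the `distant::client` import path and the example tenant/paths are assumptions):

use distant::client::{Session, SessionExt};

async fn inspect(session: &mut Session) -> Result<(), Box<dyn std::error::Error>> {
    let tenant = "my-tenant";

    // Each trait method wraps a single-item Request and converts the single
    // ResponseData that comes back into a typed value
    let exists = session.exists(tenant, "/etc/hostname").await?;
    if exists {
        let text = session.read_file_text(tenant, "/etc/hostname").await?;
        println!("hostname file contents: {}", text);
    }

    // Metadata is returned as the Metadata struct defined alongside the trait
    let meta = session.metadata(tenant, "/etc", false, false).await?;
    println!("/etc is readonly: {}", meta.readonly);
    Ok(())
}
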

@ -0,0 +1,84 @@
use crate::{client::utils, data::Response};
use std::{collections::HashMap, time::Duration};
use tokio::{io, sync::mpsc};
pub struct PostOffice {
mailboxes: HashMap<usize, mpsc::Sender<Response>>,
}
impl PostOffice {
pub fn new() -> Self {
Self {
mailboxes: HashMap::new(),
}
}
/// Creates a new mailbox using the given id and buffer size for maximum messages
pub fn make_mailbox(&mut self, id: usize, buffer: usize) -> Mailbox {
let (tx, rx) = mpsc::channel(buffer);
self.mailboxes.insert(id, tx);
Mailbox { id, rx }
}
/// Delivers a response to the appropriate mailbox, returning false if no mailbox is found
/// for the response or if the mailbox is no longer receiving responses
pub async fn deliver(&mut self, res: Response) -> bool {
let id = res.origin_id;
let success = if let Some(tx) = self.mailboxes.get_mut(&id) {
tx.send(res).await.is_ok()
} else {
false
};
// If failed, we want to remove the mailbox sender as it is no longer valid
if !success {
self.mailboxes.remove(&id);
}
success
}
/// Removes all mailboxes from the post office that are closed
pub fn prune_mailboxes(&mut self) {
self.mailboxes.retain(|_, tx| !tx.is_closed())
}
/// Closes out all mailboxes by removing the mailboxes' delivery trackers internally
pub fn clear_mailboxes(&mut self) {
self.mailboxes.clear();
}
}
pub struct Mailbox {
/// Represents id associated with the mailbox
id: usize,
/// Underlying mailbox storage
rx: mpsc::Receiver<Response>,
}
impl Mailbox {
/// Represents id associated with the mailbox
pub fn id(&self) -> usize {
self.id
}
/// Receives next response in mailbox
pub async fn next(&mut self) -> Option<Response> {
self.rx.recv().await
}
/// Receives next response in mailbox, waiting up to duration before timing out
pub async fn next_timeout(&mut self, duration: Duration) -> io::Result<Option<Response>> {
utils::timeout(duration, self.next()).await
}
/// Closes the mailbox such that it will not receive any more responses
///
/// Any responses already in the mailbox will still be returned via `next`
pub async fn close(&mut self) {
self.rx.close()
}
}
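
To make the delivery flow concrete, here is a hedged sketch of how `PostOffice` and `Mailbox` interact, written as if it were another unit test living next to this module (purely illustrative):

#[cfg(test)]
mod post_office_delivery_tests {
    use super::*;
    use crate::data::Response;

    #[tokio::test]
    async fn delivers_responses_by_origin_id() {
        let mut post_office = PostOffice::new();

        // A mailbox is registered under the originating request's id
        let mut mailbox = post_office.make_mailbox(123, 10);

        // A response whose origin_id matches gets delivered to that mailbox
        let res = Response::new("test-tenant", 123, Vec::new());
        assert!(post_office.deliver(res).await);
        assert!(mailbox.next().await.is_some());

        // A response with no registered mailbox is reported as undelivered
        let stray = Response::new("test-tenant", 999, Vec::new());
        assert!(!post_office.deliver(stray).await);
    }
}
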

@ -1,51 +1,55 @@
use crate::{
client::utils,
constants::CLIENT_BROADCAST_CHANNEL_CAPACITY,
constants::CLIENT_MAILBOX_CAPACITY,
data::{Request, Response},
net::{Codec, DataStream, Transport, TransportError, TransportWriteHalf},
net::{Codec, DataStream, Transport, TransportError},
};
use log::*;
use std::{
collections::HashMap,
convert,
net::SocketAddr,
sync::{Arc, Mutex},
ops::{Deref, DerefMut},
sync::{Arc, Weak},
};
use tokio::{
io,
net::TcpStream,
sync::{mpsc, oneshot},
sync::{mpsc, Mutex},
task::{JoinError, JoinHandle},
time::Duration,
};
mod ext;
pub use ext::SessionExt;
mod info;
pub use info::{SessionInfo, SessionInfoFile, SessionInfoParseError};
type Callbacks = Arc<Mutex<HashMap<usize, oneshot::Sender<Response>>>>;
mod mailbox;
pub use mailbox::Mailbox;
use mailbox::PostOffice;
/// Represents a session with a remote server that can be used to send requests & receive responses
pub struct Session<T, U>
where
T: DataStream,
U: Codec,
{
/// Underlying transport used by session
t_write: TransportWriteHalf<T::Write, U>,
pub struct Session {
/// Used to send requests to a server
channel: SessionChannel,
/// Collection of callbacks to be invoked upon receiving a response to a request
callbacks: Callbacks,
/// Contains the task that is running to send requests to a server
request_task: JoinHandle<()>,
/// Contains the task that is running to receive responses from a server
response_task: JoinHandle<()>,
/// Represents the receiver for broadcasted responses (ones with no callback)
pub broadcast: Option<mpsc::Receiver<Response>>,
/// Contains the task that runs on a timer to prune closed mailboxes
prune_task: JoinHandle<()>,
}
impl<U: Codec + Send + 'static> Session<TcpStream, U> {
impl Session {
/// Connect to a remote TCP server using the provided information
pub async fn tcp_connect(addr: SocketAddr, codec: U) -> io::Result<Self> {
pub async fn tcp_connect<U>(addr: SocketAddr, codec: U) -> io::Result<Self>
where
U: Codec + Send + 'static,
{
let transport = Transport::<TcpStream, U>::connect(addr, codec).await?;
debug!(
"Session has been established with {}",
@ -58,11 +62,14 @@ impl<U: Codec + Send + 'static> Session<TcpStream, U> {
}
/// Connect to a remote TCP server, timing out after duration has passed
pub async fn tcp_connect_timeout(
pub async fn tcp_connect_timeout<U>(
addr: SocketAddr,
codec: U,
duration: Duration,
) -> io::Result<Self> {
) -> io::Result<Self>
where
U: Codec + Send + 'static,
{
utils::timeout(duration, Self::tcp_connect(addr, codec))
.await
.and_then(convert::identity)
@ -70,9 +77,12 @@ impl<U: Codec + Send + 'static> Session<TcpStream, U> {
}
#[cfg(unix)]
impl<U: Codec + Send + 'static> Session<tokio::net::UnixStream, U> {
impl Session {
/// Connect to a proxy unix socket
pub async fn unix_connect(path: impl AsRef<std::path::Path>, codec: U) -> io::Result<Self> {
pub async fn unix_connect<U>(path: impl AsRef<std::path::Path>, codec: U) -> io::Result<Self>
where
U: Codec + Send + 'static,
{
let transport = Transport::<tokio::net::UnixStream, U>::connect(path, codec).await?;
debug!(
"Session has been established with {}",
@ -85,53 +95,50 @@ impl<U: Codec + Send + 'static> Session<tokio::net::UnixStream, U> {
}
/// Connect to a proxy unix socket, timing out after duration has passed
pub async fn unix_connect_timeout(
pub async fn unix_connect_timeout<U>(
path: impl AsRef<std::path::Path>,
codec: U,
duration: Duration,
) -> io::Result<Self> {
) -> io::Result<Self>
where
U: Codec + Send + 'static,
{
utils::timeout(duration, Self::unix_connect(path, codec))
.await
.and_then(convert::identity)
}
}
impl<T, U> Session<T, U>
where
T: DataStream,
U: Codec + Send + 'static,
{
impl Session {
/// Initializes a session using the provided transport
pub fn initialize(transport: Transport<T, U>) -> io::Result<Self> {
let (mut t_read, t_write) = transport.into_split();
let callbacks: Callbacks = Arc::new(Mutex::new(HashMap::new()));
let (broadcast_tx, broadcast_rx) = mpsc::channel(CLIENT_BROADCAST_CHANNEL_CAPACITY);
// Start a task that continually checks for responses and triggers callbacks
let callbacks_2 = Arc::clone(&callbacks);
pub fn initialize<T, U>(transport: Transport<T, U>) -> io::Result<Self>
where
T: DataStream,
U: Codec + Send + 'static,
{
let (mut t_read, mut t_write) = transport.into_split();
let post_office = Arc::new(Mutex::new(PostOffice::new()));
let weak_post_office = Arc::downgrade(&post_office);
// Start a task that continually checks for responses and delivers them using the
// post office
let response_task = tokio::spawn(async move {
loop {
match t_read.receive::<Response>().await {
Ok(Some(res)) => {
trace!("Incoming response: {:?}", res);
let maybe_callback = res
.origin_id
.as_ref()
.and_then(|id| callbacks_2.lock().unwrap().remove(id));
// If there is an origin to this response, trigger the callback
if let Some(tx) = maybe_callback {
trace!("Callback exists for response! Triggering!");
if let Err(res) = tx.send(res) {
error!("Failed to trigger callback for response {}", res.id);
}
// Otherwise, this goes into the junk draw of response handlers
} else {
trace!("Callback missing for response! Broadcasting!");
if let Err(x) = broadcast_tx.send(res).await {
error!("Failed to trigger broadcast: {}", x);
}
let res_id = res.id;
let res_origin_id = res.origin_id;
// Try to send response to appropriate mailbox
// NOTE: We don't log failures as errors as using fire(...) for a
// session is valid and would not have a mailbox
if !post_office.lock().await.deliver(res).await {
trace!(
"Response {} has no mailbox for origin {}",
res_id,
res_origin_id
);
}
}
Ok(None) => {
@ -144,47 +151,144 @@ where
}
}
}
// Clean up remaining mailbox senders
post_office.lock().await.clear_mailboxes();
});
let (tx, mut rx) = mpsc::channel::<Request>(1);
let request_task = tokio::spawn(async move {
while let Some(req) = rx.recv().await {
if let Err(x) = t_write.send(req).await {
error!("Failed to send request to server: {}", x);
break;
}
}
});
// Create a task that runs once a minute and prunes mailboxes
let post_office = Weak::clone(&weak_post_office);
let prune_task = tokio::spawn(async move {
loop {
tokio::time::sleep(Duration::from_secs(60)).await;
if let Some(post_office) = Weak::upgrade(&post_office) {
post_office.lock().await.prune_mailboxes();
} else {
break;
}
}
});
let channel = SessionChannel {
tx,
post_office: weak_post_office,
};
Ok(Self {
t_write,
callbacks,
broadcast: Some(broadcast_rx),
channel,
request_task,
response_task,
prune_task,
})
}
}
impl<T, U> Session<T, U>
where
T: DataStream,
U: Codec,
{
impl Session {
/// Waits for the session to terminate, which results when the receiving end of the network
/// connection is closed (or the session is shutdown)
pub async fn wait(self) -> Result<(), JoinError> {
self.response_task.await
self.prune_task.abort();
tokio::try_join!(self.request_task, self.response_task).map(|_| ())
}
/// Abort the session's current connection by forcing its response task to shutdown
pub fn abort(&self) {
self.response_task.abort()
self.request_task.abort();
self.response_task.abort();
self.prune_task.abort();
}
/// Sends a request and waits for a response
pub async fn send(&mut self, req: Request) -> Result<Response, TransportError> {
trace!("Sending request: {:?}", req);
/// Clones the underlying channel for requests and returns the cloned instance
pub fn clone_channel(&self) -> SessionChannel {
self.channel.clone()
}
}
impl Deref for Session {
type Target = SessionChannel;
fn deref(&self) -> &Self::Target {
&self.channel
}
}
// First, add a callback that will trigger when we get the response for this request
let (tx, rx) = oneshot::channel();
self.callbacks.lock().unwrap().insert(req.id, tx);
impl DerefMut for Session {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.channel
}
}
impl From<Session> for SessionChannel {
fn from(session: Session) -> Self {
session.channel
}
}
/// Represents a sender of requests tied to a session, holding onto a weak reference to the
/// mailboxes used to relay responses, meaning that once the [`Session`] is closed or dropped,
/// any sent request will no longer be able to receive responses
#[derive(Clone)]
pub struct SessionChannel {
/// Used to send requests to a server
tx: mpsc::Sender<Request>,
/// Collection of mailboxes for receiving responses to requests
post_office: Weak<Mutex<PostOffice>>,
}
impl SessionChannel {
/// Returns true if no more requests can be transferred
pub fn is_closed(&self) -> bool {
self.tx.is_closed()
}
/// Sends a request and returns a mailbox that can receive one or more responses, failing if
/// unable to send a request or if the session's receiving line to the remote server has
/// already been severed
pub async fn mail(&mut self, req: Request) -> Result<Mailbox, TransportError> {
trace!("Mailing request: {:?}", req);
// First, create a mailbox using the request's id
let mailbox = Weak::upgrade(&self.post_office)
.ok_or_else(|| {
TransportError::IoError(io::Error::new(
io::ErrorKind::NotConnected,
"Session's post office is no longer available",
))
})?
.lock()
.await
.make_mailbox(req.id, CLIENT_MAILBOX_CAPACITY);
// Second, send the request
self.t_write.send(req).await?;
self.fire(req).await?;
// Third, wait for the response
rx.await
.map_err(|x| TransportError::from(io::Error::new(io::ErrorKind::ConnectionAborted, x)))
// Third, return mailbox
Ok(mailbox)
}
/// Sends a request and waits for a response, failing if unable to send a request or if
/// the session's receiving line to the remote server has already been severed
pub async fn send(&mut self, req: Request) -> Result<Response, TransportError> {
trace!("Sending request: {:?}", req);
// Send mail and get back a mailbox
let mut mailbox = self.mail(req).await?;
// Wait for first response, and then drop the mailbox
mailbox.next().await.ok_or_else(|| {
TransportError::IoError(io::Error::from(io::ErrorKind::ConnectionAborted))
})
}
/// Sends a request and waits for a response, timing out after duration has passed
@ -199,12 +303,14 @@ where
.and_then(convert::identity)
}
/// Sends a request without waiting for a response
///
/// Any response that would be received gets sent over the broadcast channel instead
/// Sends a request without waiting for a response; this method can be used even if the
/// session's receiving line to the remote server has been severed
pub async fn fire(&mut self, req: Request) -> Result<(), TransportError> {
trace!("Firing off request: {:?}", req);
self.t_write.send(req).await
self.tx
.send(req)
.await
.map_err(|x| TransportError::IoError(io::Error::new(io::ErrorKind::BrokenPipe, x)))
}
/// Sends a request without waiting for a response, timing out after duration has passed
@ -229,6 +335,40 @@ mod tests {
};
use std::time::Duration;
#[tokio::test]
async fn mail_should_return_mailbox_that_receives_responses_until_transport_closes() {
let (t1, mut t2) = Transport::make_pair();
let mut session = Session::initialize(t1).unwrap();
let req = Request::new(TENANT, vec![RequestData::ProcList {}]);
let res = Response::new(TENANT, req.id, vec![ResponseData::Ok]);
let mut mailbox = session.mail(req).await.unwrap();
// Get first response
match tokio::join!(mailbox.next(), t2.send(res.clone())) {
(Some(actual), _) => assert_eq!(actual, res),
x => panic!("Unexpected response: {:?}", x),
}
// Get second response
match tokio::join!(mailbox.next(), t2.send(res.clone())) {
(Some(actual), _) => assert_eq!(actual, res),
x => panic!("Unexpected response: {:?}", x),
}
// Trigger the mailbox to wait BEFORE closing our transport to ensure that
// we don't get stuck if the mailbox was already waiting
let next_task = tokio::spawn(async move { mailbox.next().await });
tokio::task::yield_now().await;
drop(t2);
match next_task.await {
Ok(None) => {}
x => panic!("Unexpected response: {:?}", x),
}
}
#[tokio::test]
async fn send_should_wait_until_response_received() {
let (t1, mut t2) = Transport::make_pair();
@ -237,7 +377,7 @@ mod tests {
let req = Request::new(TENANT, vec![RequestData::ProcList {}]);
let res = Response::new(
TENANT,
Some(req.id),
req.id,
vec![ResponseData::ProcEntries {
entries: Vec::new(),
}],
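
As an illustration of the channel split above, a hedged sketch of cloning a `SessionChannel` off of a `Session` and firing a request from a spawned task (import paths assumed):

use distant::{
    client::Session,
    data::{Request, RequestData},
};

fn fire_from_background_task(session: &Session) {
    // clone_channel hands out an independent SessionChannel backed by the same
    // request task; it only holds a Weak reference to the post office
    let mut channel = session.clone_channel();

    tokio::spawn(async move {
        // fire(...) sends without registering a mailbox, so any response that
        // comes back simply has no mailbox to land in and is dropped
        let req = Request::new("my-tenant", vec![RequestData::ProcList {}]);
        if let Err(x) = channel.fire(req).await {
            eprintln!("failed to fire request: {}", x);
        }
    });
}
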

@ -1,6 +1,5 @@
/// Capacity associated with a client broadcasting its received messages that
/// do not have a callback associated
pub const CLIENT_BROADCAST_CHANNEL_CAPACITY: usize = 10000;
/// Capacity associated with a client's mailboxes for receiving multiple responses to a request
pub const CLIENT_MAILBOX_CAPACITY: usize = 10000;
/// Represents the maximum size (in bytes) that data will be read from pipes
/// per individual `read` call

@ -257,9 +257,9 @@ pub struct Response {
/// A unique id associated with the response
pub id: usize,
/// The id of the originating request, if there was one
/// (some responses are sent unprompted)
pub origin_id: Option<usize>,
/// The id of the originating request that yielded this response
/// (more than one response may have the same origin)
pub origin_id: usize,
/// The main payload containing a collection of data comprising one or more results
pub payload: Vec<ResponseData>,
@ -267,11 +267,7 @@ pub struct Response {
impl Response {
/// Creates a new response, generating a unique id for it
pub fn new(
tenant: impl Into<String>,
origin_id: Option<usize>,
payload: Vec<ResponseData>,
) -> Self {
pub fn new(tenant: impl Into<String>, origin_id: usize, payload: Vec<ResponseData>) -> Self {
let id = rand::random();
Self {
tenant: tenant.into(),
@ -486,7 +482,8 @@ impl From<tokio::task::JoinError> for ResponseData {
}
/// General purpose error type that can be sent across the wire
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)]
#[display(fmt = "{}: {}", kind, description)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct Error {
/// Label describing the kind of error
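
A tiny hedged example of the signature change above, assuming `Request` exposes public `tenant` and `id` fields as the server handler in this diff uses them (import path assumed):

use distant::data::{Request, Response, ResponseData};

// origin_id is now a plain usize; unprompted responses reuse the id of the
// request that triggered them instead of passing None
fn ack(req: &Request) -> Response {
    Response::new(req.tenant.as_str(), req.id, vec![ResponseData::Ok])
}
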

@ -1,9 +1,9 @@
mod client;
pub use client::{
LspContent, LspContentParseError, LspData, LspDataParseError, LspHeader, LspHeaderParseError,
LspSessionInfoError, RemoteLspProcess, RemoteLspStderr, RemoteLspStdin, RemoteLspStdout,
RemoteProcess, RemoteProcessError, RemoteStderr, RemoteStdin, RemoteStdout, Session,
SessionInfo, SessionInfoFile, SessionInfoParseError,
LspSessionInfoError, Mailbox, RemoteLspProcess, RemoteLspStderr, RemoteLspStdin,
RemoteLspStdout, RemoteProcess, RemoteProcessError, RemoteStderr, RemoteStdin, RemoteStdout,
Session, SessionInfo, SessionInfoFile, SessionInfoParseError,
};
mod constants;

@ -10,7 +10,9 @@ use futures::future;
use log::*;
use std::{
env,
future::Future,
path::{Path, PathBuf},
pin::Pin,
process::Stdio,
sync::Arc,
time::SystemTime,
@ -22,8 +24,8 @@ use tokio::{
};
use walkdir::WalkDir;
pub type Reply = mpsc::Sender<Response>;
type HState = Arc<Mutex<State>>;
type ReplyRet = Pin<Box<dyn Future<Output = bool> + Send + 'static>>;
#[derive(Debug, Display, Error, From)]
pub enum ServerError {
@ -60,15 +62,17 @@ pub(super) async fn process(
conn_id: usize,
state: HState,
req: Request,
tx: Reply,
tx: mpsc::Sender<Response>,
) -> Result<(), mpsc::error::SendError<Response>> {
async fn inner(
tenant: Arc<String>,
async fn inner<F>(
conn_id: usize,
state: HState,
data: RequestData,
tx: Reply,
) -> Result<Outgoing, ServerError> {
reply: F,
) -> Result<Outgoing, ServerError>
where
F: FnMut(Vec<ResponseData>) -> ReplyRet + Clone + Send + 'static,
{
match data {
RequestData::FileRead { path } => file_read(path).await,
RequestData::FileReadText { path } => file_read_text(path).await,
@ -93,9 +97,7 @@ pub(super) async fn process(
canonicalize,
resolve_file_type,
} => metadata(path, canonicalize, resolve_file_type).await,
RequestData::ProcRun { cmd, args } => {
proc_run(tenant.to_string(), conn_id, state, tx, cmd, args).await
}
RequestData::ProcRun { cmd, args } => proc_run(conn_id, state, reply, cmd, args).await,
RequestData::ProcKill { id } => proc_kill(state, id).await,
RequestData::ProcStdin { id, data } => proc_stdin(state, id, data).await,
RequestData::ProcList {} => proc_list(state).await,
@ -103,16 +105,24 @@ pub(super) async fn process(
}
}
let tenant = Arc::new(req.tenant.clone());
let reply = {
let origin_id = req.id;
let tenant = req.tenant.clone();
let tx_2 = tx.clone();
move |payload: Vec<ResponseData>| -> ReplyRet {
let tx = tx_2.clone();
let res = Response::new(tenant.to_string(), origin_id, payload);
Box::pin(async move { tx.send(res).await.is_ok() })
}
};
// Build up a collection of tasks to run independently
let mut payload_tasks = Vec::new();
for data in req.payload {
let tenant_2 = Arc::clone(&tenant);
let state_2 = Arc::clone(&state);
let tx_2 = tx.clone();
let reply_2 = reply.clone();
payload_tasks.push(tokio::spawn(async move {
match inner(tenant_2, conn_id, state_2, data, tx_2).await {
match inner(conn_id, state_2, data, reply_2).await {
Ok(outgoing) => outgoing,
Err(x) => Outgoing::from(ResponseData::from(x)),
}
@ -135,7 +145,7 @@ pub(super) async fn process(
.collect();
let payload = outgoing.into_iter().map(|x| x.data).collect();
let res = Response::new(req.tenant, Some(req.id), payload);
let res = Response::new(req.tenant, req.id, payload);
// Send out our primary response from processing the request
let result = tx.send(res).await;
@ -407,14 +417,16 @@ async fn metadata(
}))
}
async fn proc_run(
tenant: String,
async fn proc_run<F>(
conn_id: usize,
state: HState,
tx: Reply,
reply: F,
cmd: String,
args: Vec<String>,
) -> Result<Outgoing, ServerError> {
) -> Result<Outgoing, ServerError>
where
F: FnMut(Vec<ResponseData>) -> ReplyRet + Clone + Send + 'static,
{
let id = rand::random();
let mut child = Command::new(cmd.to_string())
@ -430,21 +442,16 @@ async fn proc_run(
let post_hook = Box::new(move |mut state_lock: MutexGuard<'_, State>| {
// Spawn a task that sends stdout as a response
let tx_2 = tx.clone();
let tenant_2 = tenant.clone();
let mut stdout = child.stdout.take().unwrap();
let mut reply_2 = reply.clone();
let stdout_task = tokio::spawn(async move {
let mut buf: [u8; MAX_PIPE_CHUNK_SIZE] = [0; MAX_PIPE_CHUNK_SIZE];
loop {
match stdout.read(&mut buf).await {
Ok(n) if n > 0 => match String::from_utf8(buf[..n].to_vec()) {
Ok(data) => {
let res = Response::new(
tenant_2.as_str(),
None,
vec![ResponseData::ProcStdout { id, data }],
);
if tx_2.send(res).await.is_err() {
let payload = vec![ResponseData::ProcStdout { id, data }];
if !reply_2(payload).await {
error!("<Conn @ {}> Stdout channel closed", conn_id);
break;
}
@ -468,21 +475,16 @@ async fn proc_run(
});
// Spawn a task that sends stderr as a response
let tx_2 = tx.clone();
let tenant_2 = tenant.clone();
let mut stderr = child.stderr.take().unwrap();
let mut reply_2 = reply.clone();
let stderr_task = tokio::spawn(async move {
let mut buf: [u8; MAX_PIPE_CHUNK_SIZE] = [0; MAX_PIPE_CHUNK_SIZE];
loop {
match stderr.read(&mut buf).await {
Ok(n) if n > 0 => match String::from_utf8(buf[..n].to_vec()) {
Ok(data) => {
let res = Response::new(
tenant_2.as_str(),
None,
vec![ResponseData::ProcStderr { id, data }],
);
if tx_2.send(res).await.is_err() {
let payload = vec![ResponseData::ProcStderr { id, data }];
if !reply_2(payload).await {
error!("<Conn @ {}> Stderr channel closed", conn_id);
break;
}
@ -524,6 +526,7 @@ async fn proc_run(
// kill the process when triggered
let state_2 = Arc::clone(&state);
let (kill_tx, kill_rx) = oneshot::channel();
let mut reply_2 = reply.clone();
let wait_task = tokio::spawn(async move {
tokio::select! {
status = child.wait() => {
@ -547,12 +550,8 @@ async fn proc_run(
Ok(status) => {
let success = status.success();
let code = status.code();
let res = Response::new(
tenant.as_str(),
None,
vec![ResponseData::ProcDone { id, success, code }]
);
if tx.send(res).await.is_err() {
let payload = vec![ResponseData::ProcDone { id, success, code }];
if !reply_2(payload).await {
error!(
"<Conn @ {}> Failed to send done for process {}!",
conn_id,
@ -561,8 +560,8 @@ async fn proc_run(
}
}
Err(x) => {
let res = Response::new(tenant.as_str(), None, vec![ResponseData::from(x)]);
if tx.send(res).await.is_err() {
let payload = vec![ResponseData::from(x)];
if !reply_2(payload).await {
error!(
"<Conn @ {}> Failed to send error for waiting on process {}!",
conn_id,
@ -594,10 +593,8 @@ async fn proc_run(
state_2.lock().await.remove_process(conn_id, id);
let res = Response::new(tenant.as_str(), None, vec![ResponseData::ProcDone {
id, success: false, code: None
}]);
if tx.send(res).await.is_err() {
let payload = vec![ResponseData::ProcDone { id, success: false, code: None }];
if !reply_2(payload).await {
error!("<Conn @ {}> Failed to send done for process {}!", conn_id, id);
}
}
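
To make the `reply` closure pattern above easier to follow, a standalone hedged sketch of the same shape using a stand-in tuple instead of the crate's `Response` (all names local to this example):

use std::{future::Future, pin::Pin};
use tokio::sync::mpsc;

// Mirrors the ReplyRet alias in this diff: a boxed future resolving to whether
// the reply was successfully sent
type ReplyRet = Pin<Box<dyn Future<Output = bool> + Send + 'static>>;

fn make_reply(
    tenant: String,
    origin_id: usize,
    tx: mpsc::Sender<(String, usize, Vec<String>)>, // stand-in for the Response sender
) -> impl FnMut(Vec<String>) -> ReplyRet + Clone + Send + 'static {
    move |payload: Vec<String>| -> ReplyRet {
        // Clone the sender and captured metadata per reply so the closure can
        // be invoked repeatedly and cloned into multiple tasks
        let tx = tx.clone();
        let tenant = tenant.clone();
        Box::pin(async move { tx.send((tenant, origin_id, payload)).await.is_ok() })
    }
}
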

@ -1,16 +1,15 @@
use crate::{
client::Session,
constants::CLIENT_BROADCAST_CHANNEL_CAPACITY,
data::{Request, RequestData, Response, ResponseData},
net::{Codec, DataStream, Transport, TransportReadHalf, TransportWriteHalf},
client::{Session, SessionChannel},
data::{Request, RequestData, ResponseData},
net::{Codec, DataStream, Transport},
server::utils::{ConnTracker, ShutdownTask},
};
use futures::stream::{Stream, StreamExt};
use log::*;
use std::{collections::HashMap, marker::Unpin, sync::Arc};
use tokio::{
io::{self, AsyncRead, AsyncWrite},
sync::{mpsc, oneshot, Mutex},
io,
sync::{oneshot, Mutex},
task::{JoinError, JoinHandle},
time::Duration,
};
@ -19,106 +18,33 @@ use tokio::{
/// actual server
pub struct RelayServer {
accept_task: JoinHandle<()>,
broadcast_task: JoinHandle<()>,
forward_task: JoinHandle<()>,
conns: Arc<Mutex<HashMap<usize, Conn>>>,
}
impl RelayServer {
pub fn initialize<T1, T2, U1, U2, S>(
mut session: Session<T1, U1>,
pub fn initialize<T, U, S>(
session: Session,
mut stream: S,
shutdown_after: Option<Duration>,
) -> io::Result<Self>
where
T1: DataStream + 'static,
T2: DataStream + Send + 'static,
U1: Codec + Send + 'static,
U2: Codec + Send + 'static,
S: Stream<Item = Transport<T2, U2>> + Send + Unpin + 'static,
T: DataStream + Send + 'static,
U: Codec + Send + 'static,
S: Stream<Item = Transport<T, U>> + Send + Unpin + 'static,
{
let conns: Arc<Mutex<HashMap<usize, Conn>>> = Arc::new(Mutex::new(HashMap::new()));
// Spawn task to send server responses to the appropriate connections
let conns_2 = Arc::clone(&conns);
debug!("Spawning response broadcast task");
let mut broadcast = session.broadcast.take().unwrap();
let (shutdown_broadcast_tx, mut shutdown_broadcast_rx) = mpsc::channel::<()>(1);
let broadcast_task = tokio::spawn(async move {
loop {
let res = tokio::select! {
maybe_res = broadcast.recv() => {
match maybe_res {
Some(res) => res,
None => break,
}
}
_ = shutdown_broadcast_rx.recv() => {
break;
}
};
// Search for all connections with a tenant that matches the response's tenant
for conn in conns_2.lock().await.values_mut() {
if conn.state.lock().await.tenant.as_deref() == Some(res.tenant.as_str()) {
debug!(
"Forwarding response of type{} {} to connection {}",
if res.payload.len() > 1 { "s" } else { "" },
res.to_payload_type_string(),
conn.id
);
if let Err(x) = conn.forward_response(res).await {
error!("Failed to pass forwarding message: {}", x);
}
// NOTE: We assume that tenant is unique, so we can break after
// forwarding the message to the first matching tenant
break;
}
}
}
});
// Spawn task to send to the server requests from connections
debug!("Spawning request forwarding task");
let (req_tx, mut req_rx) = mpsc::channel::<Request>(CLIENT_BROADCAST_CHANNEL_CAPACITY);
let (shutdown_forward_tx, mut shutdown_forward_rx) = mpsc::channel::<()>(1);
let forward_task = tokio::spawn(async move {
loop {
let req = tokio::select! {
maybe_req = req_rx.recv() => {
match maybe_req {
Some(req) => req,
None => break,
}
}
_ = shutdown_forward_rx.recv() => {
break;
}
};
debug!(
"Forwarding request of type{} {} to server",
if req.payload.len() > 1 { "s" } else { "" },
req.to_payload_type_string()
);
if let Err(x) = session.fire(req).await {
error!("Session failed to send request: {:?}", x);
break;
}
}
});
let (shutdown, tracker) = ShutdownTask::maybe_initialize(shutdown_after);
let conns_2 = Arc::clone(&conns);
let accept_task = tokio::spawn(async move {
let inner = async move {
loop {
let channel = session.clone_channel();
match stream.next().await {
Some(transport) => {
let result = Conn::initialize(
transport,
req_tx.clone(),
channel,
tracker.as_ref().map(Arc::clone),
)
.await;
@ -149,34 +75,19 @@ impl RelayServer {
},
None => inner.await,
}
// Doesn't matter if we send or drop these as long as they persist until this
// task is completed, so just drop
drop(shutdown_broadcast_tx);
drop(shutdown_forward_tx);
});
Ok(Self {
accept_task,
broadcast_task,
forward_task,
conns,
})
Ok(Self { accept_task, conns })
}
/// Waits for the server to terminate
pub async fn wait(self) -> Result<(), JoinError> {
match tokio::try_join!(self.accept_task, self.broadcast_task, self.forward_task) {
Ok(_) => Ok(()),
Err(x) => Err(x),
}
self.accept_task.await
}
/// Aborts the server by aborting the internal tasks and current connections
pub async fn abort(&self) {
self.accept_task.abort();
self.broadcast_task.abort();
self.forward_task.abort();
self.conns
.lock()
.await
@ -187,24 +98,13 @@ impl RelayServer {
struct Conn {
id: usize,
req_task: JoinHandle<()>,
res_task: JoinHandle<()>,
_cleanup_task: JoinHandle<()>,
res_tx: mpsc::Sender<Response>,
state: Arc<Mutex<ConnState>>,
}
/// Represents state associated with a connection
#[derive(Default)]
struct ConnState {
tenant: Option<String>,
processes: Vec<usize>,
conn_task: JoinHandle<()>,
}
impl Conn {
pub async fn initialize<T, U>(
transport: Transport<T, U>,
req_tx: mpsc::Sender<Request>,
channel: SessionChannel,
ct: Option<Arc<Mutex<ConnTracker>>>,
) -> io::Result<Self>
where
@ -215,59 +115,14 @@ impl Conn {
// is not guaranteed to have an identifiable string
let id: usize = rand::random();
let (t_read, t_write) = transport.into_split();
// Used to alert our response task of the connection's tenant name
// based on the first
let (tenant_tx, tenant_rx) = oneshot::channel();
// Create a state we use to keep track of connection-specific data
debug!("<Conn @ {}> Initializing internal state", id);
let state = Arc::new(Mutex::new(ConnState::default()));
// Mark that we have a new connection
if let Some(ct) = ct.as_ref() {
ct.lock().await.increment();
}
// Spawn task to continually receive responses from the session that
// may or may not be relevant to the connection, which will filter
// by tenant and then along any response that matches
let (res_tx, res_rx) = mpsc::channel::<Response>(CLIENT_BROADCAST_CHANNEL_CAPACITY);
let (res_task_tx, res_task_rx) = oneshot::channel();
let state_2 = Arc::clone(&state);
let res_task = tokio::spawn(async move {
handle_conn_outgoing(id, state_2, t_write, tenant_rx, res_rx).await;
let _ = res_task_tx.send(());
});
// Spawn task to continually read requests from connection and forward
// them along to be sent via the session
let req_tx = req_tx.clone();
let (req_task_tx, req_task_rx) = oneshot::channel();
let state_2 = Arc::clone(&state);
let req_task = tokio::spawn(async move {
handle_conn_incoming(id, state_2, t_read, tenant_tx, req_tx).await;
let _ = req_task_tx.send(());
});
let conn_task = spawn_conn_handler(id, transport, channel, ct).await;
let _cleanup_task = tokio::spawn(async move {
let _ = tokio::join!(req_task_rx, res_task_rx);
if let Some(ct) = ct.as_ref() {
ct.lock().await.decrement();
}
debug!("<Conn @ {}> Disconnected", id);
});
Ok(Self {
id,
req_task,
res_task,
_cleanup_task,
res_tx,
state,
})
Ok(Self { id, conn_task })
}
/// Id associated with the connection
@ -277,153 +132,125 @@ impl Conn {
/// Aborts the connection from the server side
pub fn abort(&self) {
// NOTE: We don't abort the cleanup task as that needs to actually happen
// and will even if these tasks are aborted
self.req_task.abort();
self.res_task.abort();
}
/// Forwards a response back through this connection
pub async fn forward_response(
&mut self,
res: Response,
) -> Result<(), mpsc::error::SendError<Response>> {
self.res_tx.send(res).await
self.conn_task.abort();
}
}
/// Conn::Request -> Session::Fire
async fn handle_conn_incoming<T, U>(
async fn spawn_conn_handler<T, U>(
conn_id: usize,
state: Arc<Mutex<ConnState>>,
mut reader: TransportReadHalf<T, U>,
tenant_tx: oneshot::Sender<String>,
req_tx: mpsc::Sender<Request>,
) where
T: AsyncRead + Unpin,
U: Codec,
transport: Transport<T, U>,
mut channel: SessionChannel,
ct: Option<Arc<Mutex<ConnTracker>>>,
) -> JoinHandle<()>
where
T: DataStream,
U: Codec + Send + 'static,
{
macro_rules! process_req {
($on_success:expr; $done:expr) => {
match reader.receive::<Request>().await {
Ok(Some(req)) => {
$on_success(&req);
if let Err(x) = req_tx.send(req).await {
error!(
"Failed to pass along request received on unix socket: {:?}",
x
);
$done;
}
}
Ok(None) => $done,
Err(x) => {
error!("Failed to receive request from unix stream: {:?}", x);
$done;
}
let (mut t_reader, t_writer) = transport.into_split();
let processes = Arc::new(Mutex::new(Vec::new()));
let t_writer = Arc::new(Mutex::new(t_writer));
let (done_tx, done_rx) = oneshot::channel();
let mut channel_2 = channel.clone();
let processes_2 = Arc::clone(&processes);
let task = tokio::spawn(async move {
loop {
if channel_2.is_closed() {
break;
}
};
}
let mut tenant = None;
// For each request, forward it through the session and monitor all responses
match t_reader.receive::<Request>().await {
Ok(Some(req)) => match channel_2.mail(req).await {
Ok(mut mailbox) => {
let processes = Arc::clone(&processes_2);
let t_writer = Arc::clone(&t_writer);
tokio::spawn(async move {
while let Some(res) = mailbox.next().await {
// Keep track of processes that are started so we can kill them
// when we're done
{
let mut p_lock = processes.lock().await;
for data in res.payload.iter() {
if let ResponseData::ProcStart { id } = *data {
p_lock.push(id);
}
}
}
// NOTE: Have to acquire our first request outside our loop since the oneshot
// sender of the tenant's name is consuming
process_req!(
|req: &Request| {
tenant = Some(req.tenant.clone());
if let Err(x) = tenant_tx.send(req.tenant.clone()) {
error!("Failed to send along acquired tenant name: {:?}", x);
return;
if let Err(x) = t_writer.lock().await.send(res).await {
error!(
"<Conn @ {}> Failed to send response back: {}",
conn_id, x
);
}
}
});
}
Err(x) => error!(
"<Conn @ {}> Failed to pass along request received on unix socket: {:?}",
conn_id, x
),
},
Ok(None) => break,
Err(x) => {
error!(
"<Conn @ {}> Failed to receive request from unix stream: {:?}",
conn_id, x
);
break;
}
}
};
return
);
// Loop and process all additional requests
loop {
process_req!(|_| {}; break);
}
}
// At this point, we have processed at least one request successfully
// and should have the tenant populated. If we had a failure at the
// beginning, we exit the function early via return.
let tenant = tenant.unwrap();
let _ = done_tx.send(());
});
// Perform cleanup if done by sending a request to kill each running process
// debug!("Cleaning conn {} :: killing process {}", conn_id, id);
if let Err(x) = req_tx
.send(Request::new(
tenant.clone(),
state
.lock()
.await
.processes
.iter()
.map(|id| RequestData::ProcKill { id: *id })
.collect(),
))
.await
{
error!("<Conn @ {}> Failed to send kill signals: {}", conn_id, x);
}
}
tokio::spawn(async move {
let _ = done_rx.await;
async fn handle_conn_outgoing<T, U>(
conn_id: usize,
state: Arc<Mutex<ConnState>>,
mut writer: TransportWriteHalf<T, U>,
tenant_rx: oneshot::Receiver<String>,
mut res_rx: mpsc::Receiver<Response>,
) where
T: AsyncWrite + Unpin,
U: Codec,
{
// We wait for the tenant to be identified by the first request
// before processing responses to be sent back; this is easier
// to implement and yields the same result as we would be dropping
// all responses before we know the tenant
if let Ok(tenant) = tenant_rx.await {
debug!("Associated tenant {} with conn {}", tenant, conn_id);
state.lock().await.tenant = Some(tenant.clone());
while let Some(res) = res_rx.recv().await {
debug!(
"Conn {} being sent response of type{} {}",
let p_lock = processes.lock().await;
if !p_lock.is_empty() {
trace!(
"Cleaning conn {} :: killing {} process",
conn_id,
if res.payload.len() > 1 { "s" } else { "" },
res.to_payload_type_string(),
p_lock.len()
);
// If a new process was started, we want to capture the id and
// associate it with the connection
let ids = res.payload.iter().filter_map(|x| match x {
ResponseData::ProcStart { id } => Some(*id),
_ => None,
});
for id in ids {
debug!("Tracking proc {} for conn {}", id, conn_id);
state.lock().await.processes.push(id);
if let Err(x) = channel
.fire(Request::new(
"relay",
p_lock
.iter()
.map(|id| RequestData::ProcKill { id: *id })
.collect(),
))
.await
{
error!("<Conn @ {}> Failed to send kill signals: {}", conn_id, x);
}
}
if let Err(x) = writer.send(res).await {
error!("Failed to send response through unix connection: {}", x);
break;
}
if let Some(ct) = ct.as_ref() {
ct.lock().await.decrement();
}
}
debug!("<Conn @ {}> Disconnected", conn_id);
});
task
}
#[cfg(test)]
mod tests {
use super::*;
use crate::net::{InmemoryStream, PlainCodec};
use crate::{
data::Response,
net::{InmemoryStream, PlainCodec},
};
use std::{pin::Pin, time::Duration};
use tokio::sync::mpsc;
fn make_session() -> (
Transport<InmemoryStream, PlainCodec>,
Session<InmemoryStream, PlainCodec>,
) {
fn make_session() -> (Transport<InmemoryStream, PlainCodec>, Session) {
let (t1, t2) = Transport::make_pair();
(t1, Session::initialize(t2).unwrap())
}
@ -519,11 +346,16 @@ mod tests {
// Clear out the transport channel (outbound of session)
// NOTE: Because our test stream uses a buffer size of 1, we have to clear out the
// outbound data from the earlier requests before we can send back a response
let _ = transport.receive::<Request>().await.unwrap().unwrap();
let _ = transport.receive::<Request>().await.unwrap().unwrap();
let req_1 = transport.receive::<Request>().await.unwrap().unwrap();
let req_2 = transport.receive::<Request>().await.unwrap().unwrap();
let origin_id = if req_1.tenant == "test-tenant-2" {
req_1.id
} else {
req_2.id
};
// Send a response back to a singular connection based on the tenant
let res = Response::new("test-tenant-2", None, vec![ResponseData::Ok]);
let res = Response::new("test-tenant-2", origin_id, vec![ResponseData::Ok]);
transport.send(res.clone()).await.unwrap();
// Verify that response is only received by a singular connection

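The relay now runs one `spawn_conn_handler` loop per connection: requests read from the connection's transport are forwarded through `SessionChannel::mail`, and every response delivered by the resulting `Mailbox` is written back to that same connection, which is why the old tenant-matching broadcast task could be removed. A condensed sketch of that forwarding loop follows; it assumes `SessionChannel` is re-exported at the crate root alongside `Transport`, `DataStream`, `Codec`, and `Request` (only some of those re-exports are visible in this diff) and has not been verified against the published crate.

```rust
use std::sync::Arc;
use tokio::sync::Mutex;
use distant_core::{Codec, DataStream, Request, SessionChannel, Transport};

async fn relay_requests<T, U>(
    conn_id: usize,
    transport: Transport<T, U>,
    mut channel: SessionChannel,
) where
    T: DataStream,
    U: Codec + Send + 'static,
{
    let (mut reader, writer) = transport.into_split();
    let writer = Arc::new(Mutex::new(writer));

    while let Ok(Some(req)) = reader.receive::<Request>().await {
        match channel.mail(req).await {
            Ok(mut mailbox) => {
                let writer = Arc::clone(&writer);
                // One forwarding task per request so long-lived responses
                // (e.g. process stdout) do not block later requests
                tokio::spawn(async move {
                    while let Some(res) = mailbox.next().await {
                        if writer.lock().await.send(res).await.is_err() {
                            log::error!("<Conn @ {}> Failed to send response back", conn_id);
                        }
                    }
                });
            }
            Err(x) => log::error!("<Conn @ {}> Failed to forward request: {:?}", conn_id, x),
        }
    }
}
```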
@ -106,12 +106,11 @@ impl ExitCodeError for TransportError {
impl ExitCodeError for RemoteProcessError {
fn is_silent(&self) -> bool {
matches!(self, Self::BadResponse)
true
}
fn to_exit_code(&self) -> ExitCode {
match self {
Self::BadResponse => ExitCode::DataErr,
Self::ChannelDead => ExitCode::Unavailable,
Self::TransportError(x) => x.to_exit_code(),
Self::UnexpectedEof => ExitCode::IoError,

@ -38,15 +38,7 @@ fn init_logging(opt: &opt::CommonOpt, is_remote_process: bool) -> flexi_logger::
// For each module, configure logging
for module in modules {
builder.module(
module,
match opt.verbose {
0 => LevelFilter::Warn,
1 => LevelFilter::Info,
2 => LevelFilter::Debug,
_ => LevelFilter::Trace,
},
);
builder.module(module, opt.log_level.to_log_level_filter());
// If quiet, we suppress all logging output
//

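Since `--log-level` replaces `-v` occurrence counting, `init_logging` now applies a single `LevelFilter` to every module it configures. A small sketch of that wiring, assuming the `builder` above is a `flexi_logger::LogSpecBuilder`; the module list is illustrative.

```rust
use flexi_logger::LogSpecBuilder;
use log::LevelFilter;

/// Applies one effective filter to each module of interest; how the finished
/// spec is handed back to flexi_logger is left out of this sketch
fn configure(builder: &mut LogSpecBuilder, level: LevelFilter, quiet: bool) {
    for module in ["distant", "distant_core"] {
        // --quiet still overrides whatever --log-level was chosen
        builder.module(module, if quiet { LevelFilter::Off } else { level });
    }
}
```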
@ -40,17 +40,58 @@ impl Opt {
}
}
#[derive(
Copy,
Clone,
Debug,
Display,
PartialEq,
Eq,
IsVariant,
IntoStaticStr,
EnumString,
EnumVariantNames,
)]
#[strum(serialize_all = "snake_case")]
pub enum LogLevel {
Off,
Error,
Warn,
Info,
Debug,
Trace,
}
impl LogLevel {
pub fn to_log_level_filter(self) -> log::LevelFilter {
match self {
Self::Off => log::LevelFilter::Off,
Self::Error => log::LevelFilter::Error,
Self::Warn => log::LevelFilter::Warn,
Self::Info => log::LevelFilter::Info,
Self::Debug => log::LevelFilter::Debug,
Self::Trace => log::LevelFilter::Trace,
}
}
}
/// Contains options that are common across subcommands
#[derive(Debug, StructOpt)]
pub struct CommonOpt {
/// Verbose mode (-v, -vv, -vvv, etc.)
#[structopt(short, long, parse(from_occurrences), global = true)]
pub verbose: u8,
/// Quiet mode, suppresses all logging
/// Quiet mode, suppresses all logging (shortcut for log level off)
#[structopt(short, long, global = true)]
pub quiet: bool,
/// Log level to use throughout the application
#[structopt(
long,
global = true,
case_insensitive = true,
default_value = LogLevel::Info.into(),
possible_values = LogLevel::VARIANTS
)]
pub log_level: LogLevel,
/// Log output to disk instead of stderr
#[structopt(long, global = true)]
pub log_file: Option<PathBuf>,

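The new `--log-level` flag accepts one of the `LogLevel` variant names (case-insensitively, defaulting to `info`), while `--quiet` stays as a shortcut for turning logging off. Below is a dependency-free sketch of that behavior; this `LogLevel` is a hand-rolled stand-in for the structopt/strum-derived enum above.

```rust
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum LogLevel { Off, Error, Warn, Info, Debug, Trace }

impl std::str::FromStr for LogLevel {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_ascii_lowercase().as_str() {
            "off" => Ok(Self::Off),
            "error" => Ok(Self::Error),
            "warn" => Ok(Self::Warn),
            "info" => Ok(Self::Info),
            "debug" => Ok(Self::Debug),
            "trace" => Ok(Self::Trace),
            other => Err(format!("invalid log level: {}", other)),
        }
    }
}

fn main() {
    // Mirrors `--log-level debug` with `case_insensitive = true`
    let level: LogLevel = "DEBUG".parse().unwrap();
    assert_eq!(level, LogLevel::Debug);

    // `--quiet` still acts as a shortcut for `--log-level off`
    let quiet = true;
    let effective = if quiet { LogLevel::Off } else { level };
    println!("effective log level: {:?}", effective);
}
```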
@ -1,35 +1,38 @@
use crate::{
buf::StringBuf, constants::MAX_PIPE_CHUNK_SIZE, opt::Format, output::ResponseOut, stdin,
};
use distant_core::{Codec, DataStream, Request, RequestData, Response, Session};
use distant_core::{Mailbox, Request, RequestData, Session};
use log::*;
use std::{io, thread};
use std::io;
use structopt::StructOpt;
use tokio::{sync::mpsc, task::JoinHandle};
use tokio::{
sync::{mpsc, oneshot},
task::JoinHandle,
};
/// Represents a wrapper around a session that provides CLI functionality such as reading from
/// stdin and piping results back out to stdout
pub struct CliSession {
_stdin_thread: thread::JoinHandle<()>,
req_task: JoinHandle<()>,
res_task: JoinHandle<io::Result<()>>,
}
impl CliSession {
pub fn new<T, U>(tenant: String, mut session: Session<T, U>, format: Format) -> Self
where
T: DataStream + 'static,
U: Codec + Send + 'static,
{
let (stdin_thread, stdin_rx) = stdin::spawn_channel(MAX_PIPE_CHUNK_SIZE);
/// Creates a new instance of a session for use in CLI interactions that is fed input
/// from the program's stdin
pub fn new_for_stdin(tenant: String, session: Session, format: Format) -> Self {
let (_stdin_thread, stdin_rx) = stdin::spawn_channel(MAX_PIPE_CHUNK_SIZE);
let (exit_tx, exit_rx) = mpsc::channel(1);
let broadcast = session.broadcast.take().unwrap();
let res_task =
tokio::spawn(
async move { process_incoming_responses(broadcast, format, exit_rx).await },
);
Self::new(tenant, session, format, stdin_rx)
}
/// Creates a new instance of a session for use in CLI interactions that is fed input
/// from the provided receiver
pub fn new(
tenant: String,
session: Session,
format: Format,
stdin_rx: mpsc::Receiver<String>,
) -> Self {
let map_line = move |line: &str| match format {
Format::Json => serde_json::from_str(line)
.map_err(|x| io::Error::new(io::ErrorKind::InvalidInput, x)),
@ -44,61 +47,53 @@ impl CliSession {
}
};
let req_task = tokio::spawn(async move {
process_outgoing_requests(session, stdin_rx, exit_tx, format, map_line).await
process_outgoing_requests(session, stdin_rx, format, map_line).await
});
Self {
_stdin_thread: stdin_thread,
req_task,
res_task,
}
Self { req_task }
}
/// Wait for the cli session to terminate
pub async fn wait(self) -> io::Result<()> {
match tokio::try_join!(self.req_task, self.res_task) {
Ok((_, res)) => res,
match self.req_task.await {
Ok(res) => Ok(res),
Err(x) => Err(io::Error::new(io::ErrorKind::BrokenPipe, x)),
}
}
}
/// Helper function that loops, processing incoming responses not tied to a request to be sent out
/// over stdout/stderr
async fn process_incoming_responses(
mut broadcast: mpsc::Receiver<Response>,
format: Format,
mut exit: mpsc::Receiver<()>,
) -> io::Result<()> {
loop {
tokio::select! {
res = broadcast.recv() => {
match res {
Some(res) => ResponseOut::new(format, res)?.print(),
None => return Ok(()),
/// Helper function that loops, processing incoming responses to a mailbox
async fn process_mailbox(mut mailbox: Mailbox, format: Format, exit: oneshot::Receiver<()>) {
let inner = async move {
while let Some(res) = mailbox.next().await {
match ResponseOut::new(format, res) {
Ok(out) => out.print(),
Err(x) => {
error!("{}", x);
break;
}
}
_ = exit.recv() => {
return Ok(());
}
}
};
tokio::select! {
_ = inner => {}
_ = exit => {}
}
}
/// Helper function that loops, processing outgoing requests created from stdin, and printing out
/// responses
async fn process_outgoing_requests<T, U, F>(
mut session: Session<T, U>,
async fn process_outgoing_requests<F>(
mut session: Session,
mut stdin_rx: mpsc::Receiver<String>,
exit_tx: mpsc::Sender<()>,
format: Format,
map_line: F,
) where
T: DataStream,
U: Codec,
F: Fn(&str) -> io::Result<Request>,
{
let mut buf = StringBuf::new();
let mut mailbox_exits = Vec::new();
while let Some(data) = stdin_rx.recv().await {
// Update our buffer with the new data and split it into concrete lines and remainder
@ -112,21 +107,31 @@ async fn process_outgoing_requests<T, U, F>(
trace!("Processing line: {:?}", line);
if line.is_empty() {
continue;
} else if line == "exit" {
debug!("Got exit request, so closing cli session");
stdin_rx.close();
if exit_tx.send(()).await.is_err() {
error!("Failed to close cli session");
}
continue;
}
match map_line(line) {
Ok(req) => match session.send(req).await {
Ok(res) => match ResponseOut::new(format, res) {
Ok(out) => out.print(),
Err(x) => error!("Failed to format response: {}", x),
},
Ok(req) => match session.mail(req).await {
Ok(mut mailbox) => {
// Wait to get our first response before moving on to the next line
// of input
if let Some(res) = mailbox.next().await {
// Convert the response to output and, when successful, launch
// a handler for continued responses to the same request,
// such as with processes
match ResponseOut::new(format, res) {
Ok(out) => {
out.print();
let (tx, rx) = oneshot::channel();
mailbox_exits.push(tx);
tokio::spawn(process_mailbox(mailbox, format, rx));
}
Err(x) => {
error!("{}", x);
}
}
}
}
Err(x) => {
error!("Failed to send request: {}", x)
}
@ -138,4 +143,9 @@ async fn process_outgoing_requests<T, U, F>(
}
}
}
// Close out any dangling mailbox handlers
for tx in mailbox_exits {
let _ = tx.send(());
}
}

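`process_mailbox` above drains a `Mailbox` until either the stream ends or a `oneshot` exit signal fires, which is how `process_outgoing_requests` tears down its per-request handlers once stdin is exhausted. The same select-until-told-to-stop pattern is shown in isolation below, with a plain `mpsc` receiver standing in for a `Mailbox`.

```rust
use tokio::sync::{mpsc, oneshot};
use tokio::time::{sleep, Duration};

async fn drain_until_exit(mut rx: mpsc::Receiver<String>, exit: oneshot::Receiver<()>) {
    let inner = async move {
        while let Some(line) = rx.recv().await {
            println!("{}", line);
        }
    };
    tokio::select! {
        _ = inner => {}
        _ = exit => {}
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::channel(4);
    let (exit_tx, exit_rx) = oneshot::channel();
    let handle = tokio::spawn(drain_until_exit(rx, exit_rx));

    tx.send(String::from("first response")).await.unwrap();
    sleep(Duration::from_millis(50)).await;

    // Later (e.g. when stdin closes or "exit" is typed) the spawned handler is
    // told to stop waiting on further responses
    let _ = exit_tx.send(());
    handle.await.unwrap();
}
```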
@ -8,8 +8,7 @@ use crate::{
};
use derive_more::{Display, Error, From};
use distant_core::{
Codec, DataStream, LspData, RemoteProcess, RemoteProcessError, Request, RequestData, Session,
TransportError,
LspData, RemoteProcess, RemoteProcessError, Request, RequestData, Session, TransportError,
};
use tokio::{io, time::Duration};
@ -66,23 +65,20 @@ async fn run_async(cmd: ActionSubcommand, opt: CommonOpt) -> Result<(), Error> {
)
}
async fn start<T, U>(
async fn start(
cmd: ActionSubcommand,
mut session: Session<T, U>,
mut session: Session,
timeout: Duration,
lsp_data: Option<LspData>,
) -> Result<(), Error>
where
T: DataStream + 'static,
U: Codec + Send + 'static,
{
) -> Result<(), Error> {
let is_shell_format = matches!(cmd.format, Format::Shell);
match (cmd.interactive, cmd.operation) {
// ProcRun request w/ shell format is specially handled and we ignore interactive as
// the stdin will be used for sending ProcStdin to remote process
(_, Some(RequestData::ProcRun { cmd, args })) if is_shell_format => {
let mut proc = RemoteProcess::spawn(utils::new_tenant(), session, cmd, args).await?;
let mut proc =
RemoteProcess::spawn(utils::new_tenant(), &mut session, cmd, args).await?;
// If we also parsed an LSP's initialize request for its session, we want to forward
// it along in the case of a process call
@ -97,6 +93,12 @@ where
proc.stderr.take().unwrap(),
);
// Drop main session as the singular remote process will now manage stdin/stdout/stderr
// NOTE: Without this, severing stdin from this side would not occur as we would
// continue to maintain a second reference to the remote connection's input
// through the primary session
drop(session);
let (success, exit_code) = proc.wait().await?;
// Shut down our link
@ -145,7 +147,7 @@ where
// Enter into CLI session where we receive requests on stdin and send out
// over stdout/stderr
let cli_session = CliSession::new(utils::new_tenant(), session, cmd.format);
let cli_session = CliSession::new_for_stdin(utils::new_tenant(), session, cmd.format);
cli_session.wait().await?;
Ok(())

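Because `RemoteProcess::spawn` now borrows the session rather than consuming it, the caller drops the session once the spawned process should own the only handle to the connection, as the NOTE above explains. A minimal sketch of that flow; the tenant name and command are illustrative, and treating `RemoteProcessError` as a boxed `std::error::Error` is an assumption.

```rust
use distant_core::{RemoteProcess, Session};

async fn run_remote(mut session: Session) -> Result<(), Box<dyn std::error::Error>> {
    let mut proc = RemoteProcess::spawn(
        String::from("example-tenant"),
        &mut session,
        String::from("echo"),
        vec![String::from("hello")],
    )
    .await?;

    // The CLI links stdin/stdout/stderr here; for this sketch we just detach stdin
    let _stdin = proc.stdin.take();

    // Drop the session so the remote process holds the only handle to the
    // underlying connection, letting stdin be severed from this side
    drop(session);

    let (success, exit_code) = proc.wait().await?;
    println!("success = {}, exit code = {:?}", success, exit_code);
    Ok(())
}
```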
@ -129,7 +129,7 @@ async fn keep_loop(info: SessionInfo, format: Format, duration: Duration) -> io:
let codec = XChaCha20Poly1305Codec::from(info.key);
match Session::tcp_connect_timeout(addr, codec, duration).await {
Ok(session) => {
let cli_session = CliSession::new(utils::new_tenant(), session, format);
let cli_session = CliSession::new_for_stdin(utils::new_tenant(), session, format);
cli_session.wait().await
}
Err(x) => Err(x),

@ -5,7 +5,7 @@ use crate::{
utils,
};
use derive_more::{Display, Error, From};
use distant_core::{Codec, DataStream, LspData, RemoteLspProcess, RemoteProcessError, Session};
use distant_core::{LspData, RemoteLspProcess, RemoteProcessError, Session};
use tokio::io;
#[derive(Debug, Display, Error, From)]
@ -53,16 +53,13 @@ async fn run_async(cmd: LspSubcommand, opt: CommonOpt) -> Result<(), Error> {
)
}
async fn start<T, U>(
async fn start(
cmd: LspSubcommand,
session: Session<T, U>,
mut session: Session,
lsp_data: Option<LspData>,
) -> Result<(), Error>
where
T: DataStream + 'static,
U: Codec + Send + 'static,
{
let mut proc = RemoteLspProcess::spawn(utils::new_tenant(), session, cmd.cmd, cmd.args).await?;
) -> Result<(), Error> {
let mut proc =
RemoteLspProcess::spawn(utils::new_tenant(), &mut session, cmd.cmd, cmd.args).await?;
// If we also parsed an LSP's initialize request for its session, we want to forward
// it along in the case of a process call

@ -327,7 +327,6 @@ fn should_support_json_to_forward_stdin_to_remote_process(ctx: &'_ DistantServer
let mut child = distant_subcommand(ctx, "action")
.args(&["--format", "json"])
.arg("--interactive")
.args(&["--log-file", "/tmp/test.log", "-vvv"])
.spawn()
.unwrap();

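The `-vvv` flag removed from this test no longer exists after the switch to `--log-level`; the equivalent verbosity is now requested explicitly. A hypothetical invocation showing the replacement arguments (the binary name, subcommand ordering, and immediate shutdown are assumptions, not taken from the test helper):

```rust
use std::process::Command;

fn main() {
    // Hypothetical invocation (not part of the commit) showing what the removed
    // `-vvv` maps to after this change: an explicit `--log-level trace`
    match Command::new("distant")
        .args(&["action", "--format", "json", "--interactive"])
        .args(&["--log-file", "/tmp/test.log", "--log-level", "trace"])
        .spawn()
    {
        Ok(mut child) => {
            // Shut the child down right away; we only care about flag parsing here
            let _ = child.kill();
            let _ = child.wait();
        }
        Err(x) => eprintln!("failed to spawn distant: {}", x),
    }
}
```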