Add status method to RemoteProcess and lua module equivalent

pull/96/head
Chip Senkbeil 3 years ago
parent f021869310
commit a8b6f3eb31
No known key found for this signature in database
GPG Key ID: 35EF1F8EC72A4131

8
Cargo.lock generated

@ -427,7 +427,7 @@ dependencies = [
[[package]]
name = "distant"
version = "0.15.0-alpha.14"
version = "0.15.0-alpha.15"
dependencies = [
"assert_cmd",
"assert_fs",
@ -451,7 +451,7 @@ dependencies = [
[[package]]
name = "distant-core"
version = "0.15.0-alpha.14"
version = "0.15.0-alpha.15"
dependencies = [
"assert_fs",
"bytes",
@ -476,7 +476,7 @@ dependencies = [
[[package]]
name = "distant-lua"
version = "0.15.0-alpha.14"
version = "0.15.0-alpha.15"
dependencies = [
"distant-core",
"distant-ssh2",
@ -510,7 +510,7 @@ dependencies = [
[[package]]
name = "distant-ssh2"
version = "0.15.0-alpha.14"
version = "0.15.0-alpha.15"
dependencies = [
"assert_cmd",
"assert_fs",

@ -6,11 +6,15 @@ use crate::{
};
use derive_more::{Display, Error, From};
use log::*;
use std::sync::Arc;
use tokio::{
io,
sync::mpsc::{
self,
error::{TryRecvError, TrySendError},
sync::{
mpsc::{
self,
error::{TryRecvError, TrySendError},
},
RwLock,
},
task::{JoinError, JoinHandle},
};
@ -40,12 +44,11 @@ pub struct RemoteProcess {
/// Id used to map back to mailbox
pub(crate) origin_id: usize,
/// Task that forwards stdin to the remote process by bundling it as stdin requests
req_task: JoinHandle<Result<(), RemoteProcessError>>,
// Sender to abort req task
abort_req_task_tx: mpsc::Sender<()>,
/// Task that reads in new responses, which returns the success and optional
/// exit code once the process has completed
res_task: JoinHandle<Result<(bool, Option<i32>), RemoteProcessError>>,
// Sender to abort res task
abort_res_task_tx: mpsc::Sender<()>,
/// Sender for stdin
pub stdin: Option<RemoteStdin>,
@ -58,6 +61,12 @@ pub struct RemoteProcess {
/// Sender for kill events
kill: mpsc::Sender<()>,
/// Task that waits for the process to complete
wait_task: JoinHandle<()>,
/// Handles the success and exit code for a completed process
status: Arc<RwLock<Option<Result<(bool, Option<i32>), RemoteProcessError>>>>,
}
impl RemoteProcess {
@ -125,28 +134,56 @@ impl RemoteProcess {
// Used to terminate request task, either explicitly by the process or internally
// by the response task when it terminates
let (kill_tx, kill_rx) = mpsc::channel(1);
let kill_tx_2 = kill_tx.clone();
// Now we spawn a task to handle future responses that are async
// such as ProcStdout, ProcStderr, and ProcDone
let kill_tx_2 = kill_tx.clone();
let (abort_res_task_tx, mut abort_res_task_rx) = mpsc::channel::<()>(1);
let res_task = tokio::spawn(async move {
process_incoming_responses(id, mailbox, stdout_tx, stderr_tx, kill_tx_2).await
tokio::select! {
_ = abort_res_task_rx.recv() => {
panic!("killed");
}
res = process_incoming_responses(id, mailbox, stdout_tx, stderr_tx, kill_tx_2) => {
res
}
}
});
// Spawn a task that takes stdin from our channel and forwards it to the remote process
let (abort_req_task_tx, mut abort_req_task_rx) = mpsc::channel::<()>(1);
let req_task = tokio::spawn(async move {
process_outgoing_requests(tenant, id, channel, stdin_rx, kill_rx).await
tokio::select! {
_ = abort_req_task_rx.recv() => {
panic!("killed");
}
res = process_outgoing_requests(tenant, id, channel, stdin_rx, kill_rx) => {
res
}
}
});
let status = Arc::new(RwLock::new(None));
let status_2 = Arc::clone(&status);
let wait_task = tokio::spawn(async move {
let res = match tokio::try_join!(req_task, res_task) {
Ok((_, res)) => res,
Err(x) => Err(RemoteProcessError::from(x)),
};
status_2.write().await.replace(res);
});
Ok(Self {
id,
origin_id,
req_task,
res_task,
abort_req_task_tx,
abort_res_task_tx,
stdin: Some(RemoteStdin(stdin_tx)),
stdout: Some(RemoteStdout(stdout_rx)),
stderr: Some(RemoteStderr(stderr_rx)),
kill: kill_tx,
wait_task,
status,
})
}
@ -155,20 +192,36 @@ impl RemoteProcess {
self.id
}
/// Checks if the process has completed, returning the exit status if it has, without
/// consuming the process itself. Note that this does not include join errors that can
/// occur when aborting and instead converts any error to a status of false. To acquire
/// the actual error, you must call `wait`
pub async fn status(&self) -> Option<(bool, Option<i32>)> {
self.status.read().await.as_ref().map(|x| match x {
Ok((success, exit_code)) => (*success, *exit_code),
Err(_) => (false, None),
})
}
/// Waits for the process to terminate, returning the success status and an optional exit code
pub async fn wait(self) -> Result<(bool, Option<i32>), RemoteProcessError> {
match tokio::try_join!(self.req_task, self.res_task) {
Ok((_, res)) => res,
Err(x) => Err(RemoteProcessError::from(x)),
}
// Wait for the process to complete before we try to get the status
let _ = self.wait_task.await;
// NOTE: If we haven't received an exit status, this lines up with the UnexpectedEof error
self.status
.write()
.await
.take()
.unwrap_or_else(|| Err(RemoteProcessError::UnexpectedEof))
}
/// Aborts the process by forcing its response task to shutdown, which means that a call
/// to `wait` will return an error. Note that this does **not** send a kill request, so if
/// you want to be nice you should send the request before aborting.
pub fn abort(&self) {
self.req_task.abort();
self.res_task.abort();
let _ = self.abort_req_task_tx.try_send(());
let _ = self.abort_res_task_tx.try_send(());
}
/// Submits a kill request for the running process
@ -352,6 +405,7 @@ mod tests {
data::{Error, ErrorKind, Response},
net::{InmemoryStream, PlainCodec, Transport},
};
use std::time::Duration;
fn make_session() -> (Transport<InmemoryStream, PlainCodec>, Session) {
let (t1, t2) = Transport::make_pair();
@ -702,6 +756,145 @@ mod tests {
assert_eq!(out, "some err");
}
#[tokio::test]
async fn status_should_return_none_if_not_done() {
let (mut transport, session) = make_session();
// Create a task for process spawning as we need to handle the request and a response
// in a separate async block
let spawn_task = tokio::spawn(async move {
RemoteProcess::spawn(
String::from("test-tenant"),
session.clone_channel(),
String::from("cmd"),
vec![String::from("arg")],
false,
)
.await
});
// Wait until we get the request from the session
let req = transport.receive::<Request>().await.unwrap().unwrap();
// Send back a response through the session
let id = 12345;
transport
.send(Response::new(
"test-tenant",
req.id,
vec![ResponseData::ProcStart { id }],
))
.await
.unwrap();
// Receive the process and then check its status
let proc = spawn_task.await.unwrap().unwrap();
let result = proc.status().await;
assert_eq!(result, None, "Unexpectedly got proc status: {:?}", result);
}
#[tokio::test]
async fn status_should_return_false_for_success_if_internal_tasks_fail() {
let (mut transport, session) = make_session();
// Create a task for process spawning as we need to handle the request and a response
// in a separate async block
let spawn_task = tokio::spawn(async move {
RemoteProcess::spawn(
String::from("test-tenant"),
session.clone_channel(),
String::from("cmd"),
vec![String::from("arg")],
false,
)
.await
});
// Wait until we get the request from the session
let req = transport.receive::<Request>().await.unwrap().unwrap();
// Send back a response through the session
let id = 12345;
transport
.send(Response::new(
"test-tenant",
req.id,
vec![ResponseData::ProcStart { id }],
))
.await
.unwrap();
// Receive the process and then abort it to make internal tasks fail
let proc = spawn_task.await.unwrap().unwrap();
proc.abort();
// Wait a bit to ensure the other tasks abort
tokio::time::sleep(Duration::from_millis(100)).await;
// Peek at the status to confirm the result
let result = proc.status().await;
match result {
Some((false, None)) => {}
x => panic!("Unexpected result: {:?}", x),
}
}
#[tokio::test]
async fn status_should_return_process_status_when_done() {
let (mut transport, session) = make_session();
// Create a task for process spawning as we need to handle the request and a response
// in a separate async block
let spawn_task = tokio::spawn(async move {
RemoteProcess::spawn(
String::from("test-tenant"),
session.clone_channel(),
String::from("cmd"),
vec![String::from("arg")],
false,
)
.await
});
// Wait until we get the request from the session
let req = transport.receive::<Request>().await.unwrap().unwrap();
// Send back a response through the session
let id = 12345;
transport
.send(Response::new(
"test-tenant",
req.id,
vec![ResponseData::ProcStart { id }],
))
.await
.unwrap();
// Receive the process and then spawn a task for it to complete
let proc = spawn_task.await.unwrap().unwrap();
// Send a process completion response to pass along exit status and conclude wait
transport
.send(Response::new(
"test-tenant",
req.id,
vec![ResponseData::ProcDone {
id,
success: true,
code: Some(123),
}],
))
.await
.unwrap();
// Wait a bit to ensure the status gets transmitted
tokio::time::sleep(Duration::from_millis(100)).await;
// Finally, verify that we complete and get the expected results
assert_eq!(proc.status().await, Some((true, Some(123))));
}
#[tokio::test]
async fn wait_should_return_error_if_internal_tasks_fail() {
let (mut transport, session) = make_session();

@ -155,6 +155,19 @@ macro_rules! impl_process {
})
}
fn status(id: usize) -> LuaResult<Option<Status>> {
runtime::block_on(Self::status_async(id))
}
async fn status_async(id: usize) -> LuaResult<Option<Status>> {
with_proc_async!($map_name, id, proc -> {
Ok(proc.status().await.map(|(success, exit_code)| Status {
success,
exit_code,
}))
})
}
fn wait(id: usize) -> LuaResult<(bool, Option<i32>)> {
runtime::block_on(Self::wait_async(id))
}
@ -238,6 +251,10 @@ macro_rules! impl_process {
methods.add_async_method("read_stderr_async", |_, this, ()| {
runtime::spawn(Self::read_stderr_async(this.id))
});
methods.add_method("status", |_, this, ()| Self::status(this.id));
methods.add_async_method("status_async", |_, this, ()| {
runtime::spawn(Self::status_async(this.id))
});
methods.add_method("wait", |_, this, ()| Self::wait(this.id));
methods.add_async_method("wait_async", |_, this, ()| {
runtime::spawn(Self::wait_async(this.id))
@ -256,6 +273,29 @@ macro_rules! impl_process {
};
}
/// Represents process status
#[derive(Clone, Debug)]
pub struct Status {
pub success: bool,
pub exit_code: Option<i32>,
}
impl UserData for Status {
fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) {
fields.add_field_method_get("success", |_, this| Ok(this.success));
fields.add_field_method_get("exit_code", |_, this| Ok(this.exit_code));
}
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
methods.add_method("to_tbl", |lua, this, ()| {
let tbl = lua.create_table()?;
tbl.set("success", this.success)?;
tbl.set("exit_code", this.exit_code)?;
Ok(tbl)
});
}
}
/// Represents process output
#[derive(Clone, Debug)]
pub struct Output {

Loading…
Cancel
Save