diff --git a/CHANGELOG.md b/CHANGELOG.md index 77bf3ea..f0ebade 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 `distant-local` implementation sending a separate `Changed` event per path - `ChangeDetails` now includes a `renamed` field to capture the new path name when known +- `DistantApi` now handles batch requests in parallel, returning the results in + order. To achieve the previous sequential processing of batch requests, the + header value `sequence` needs to be set to true ## [0.20.0-alpha.8] diff --git a/distant-core/src/api.rs b/distant-core/src/api.rs index 41ef0f5..1c15e6c 100644 --- a/distant-core/src/api.rs +++ b/distant-core/src/api.rs @@ -27,7 +27,7 @@ pub struct DistantApiServerHandler where T: DistantApi, { - api: T, + api: Arc, } impl DistantApiServerHandler @@ -35,7 +35,7 @@ where T: DistantApi, { pub fn new(api: T) -> Self { - Self { api } + Self { api: Arc::new(api) } } } @@ -424,8 +424,8 @@ pub trait DistantApi { #[async_trait] impl ServerHandler for DistantApiServerHandler where - T: DistantApi + Send + Sync, - D: Send + Sync, + T: DistantApi + Send + Sync + 'static, + D: Send + Sync + 'static, { type LocalData = D; type Request = protocol::Msg; @@ -457,7 +457,7 @@ where local_data, }; - let data = handle_request(self, ctx, data).await; + let data = handle_request(Arc::clone(&self.api), ctx, data).await; // Report outgoing errors in our debug logs if let protocol::Response::Error(x) = &data { @@ -466,27 +466,35 @@ where protocol::Msg::Single(data) } - protocol::Msg::Batch(list) => { + protocol::Msg::Batch(list) + if matches!(request.header.get_as("sequence"), Some(Ok(true))) => + { let mut out = Vec::new(); + let mut has_failed = false; for data in list { + // Once we hit a failure, all remaining requests return interrupted + if has_failed { + out.push(protocol::Response::Error(protocol::Error { + kind: protocol::ErrorKind::Interrupted, + description: String::from("Canceled due to earlier error"), + })); + continue; + } + let ctx = DistantCtx { connection_id, reply: Box::new(DistantSingleReply::from(reply.clone_reply())), local_data: Arc::clone(&local_data), }; - // TODO: This does not run in parallel, meaning that the next item in the - // batch will not be queued until the previous item completes! This - // would be useful if we wanted to chain requests where the previous - // request feeds into the current request, but not if we just want - // to run everything together. So we should instead rewrite this - // to spawn a task per request and then await completion of all tasks - let data = handle_request(self, ctx, data).await; + let data = handle_request(Arc::clone(&self.api), ctx, data).await; - // Report outgoing errors in our debug logs + // Report outgoing errors in our debug logs and mark as failed + // to cancel any future tasks being run if let protocol::Response::Error(x) = &data { debug!("[Conn {}] {}", connection_id, x); + has_failed = true; } out.push(data); @@ -494,6 +502,44 @@ where protocol::Msg::Batch(out) } + protocol::Msg::Batch(list) => { + let mut tasks = Vec::new(); + + // If sequence specified as true, we want to process in order, otherwise we can + // process in any order + + for data in list { + let api = Arc::clone(&self.api); + let ctx = DistantCtx { + connection_id, + reply: Box::new(DistantSingleReply::from(reply.clone_reply())), + local_data: Arc::clone(&local_data), + }; + + let task = tokio::spawn(async move { + let data = handle_request(api, ctx, data).await; + + // Report outgoing errors in our debug logs + if let protocol::Response::Error(x) = &data { + debug!("[Conn {}] {}", connection_id, x); + } + + data + }); + + tasks.push(task); + } + + let out = futures::future::join_all(tasks) + .await + .into_iter() + .map(|x| match x { + Ok(x) => x, + Err(x) => protocol::Response::Error(x.to_string().into()), + }) + .collect(); + protocol::Msg::Batch(out) + } }; // Queue up our result to go before ANY of the other messages that might be sent. @@ -515,7 +561,7 @@ where /// Processes an incoming request async fn handle_request( - server: &DistantApiServerHandler, + api: Arc, ctx: DistantCtx, request: protocol::Request, ) -> protocol::Response @@ -524,44 +570,37 @@ where D: Send + Sync, { match request { - protocol::Request::Version {} => server - .api + protocol::Request::Version {} => api .version(ctx) .await .map(protocol::Response::Version) .unwrap_or_else(protocol::Response::from), - protocol::Request::FileRead { path } => server - .api + protocol::Request::FileRead { path } => api .read_file(ctx, path) .await .map(|data| protocol::Response::Blob { data }) .unwrap_or_else(protocol::Response::from), - protocol::Request::FileReadText { path } => server - .api + protocol::Request::FileReadText { path } => api .read_file_text(ctx, path) .await .map(|data| protocol::Response::Text { data }) .unwrap_or_else(protocol::Response::from), - protocol::Request::FileWrite { path, data } => server - .api + protocol::Request::FileWrite { path, data } => api .write_file(ctx, path, data) .await .map(|_| protocol::Response::Ok) .unwrap_or_else(protocol::Response::from), - protocol::Request::FileWriteText { path, text } => server - .api + protocol::Request::FileWriteText { path, text } => api .write_file_text(ctx, path, text) .await .map(|_| protocol::Response::Ok) .unwrap_or_else(protocol::Response::from), - protocol::Request::FileAppend { path, data } => server - .api + protocol::Request::FileAppend { path, data } => api .append_file(ctx, path, data) .await .map(|_| protocol::Response::Ok) .unwrap_or_else(protocol::Response::from), - protocol::Request::FileAppendText { path, text } => server - .api + protocol::Request::FileAppendText { path, text } => api .append_file_text(ctx, path, text) .await .map(|_| protocol::Response::Ok) @@ -572,8 +611,7 @@ where absolute, canonicalize, include_root, - } => server - .api + } => api .read_dir(ctx, path, depth, absolute, canonicalize, include_root) .await .map(|(entries, errors)| protocol::Response::DirEntries { @@ -581,26 +619,22 @@ where errors: errors.into_iter().map(Error::from).collect(), }) .unwrap_or_else(protocol::Response::from), - protocol::Request::DirCreate { path, all } => server - .api + protocol::Request::DirCreate { path, all } => api .create_dir(ctx, path, all) .await .map(|_| protocol::Response::Ok) .unwrap_or_else(protocol::Response::from), - protocol::Request::Remove { path, force } => server - .api + protocol::Request::Remove { path, force } => api .remove(ctx, path, force) .await .map(|_| protocol::Response::Ok) .unwrap_or_else(protocol::Response::from), - protocol::Request::Copy { src, dst } => server - .api + protocol::Request::Copy { src, dst } => api .copy(ctx, src, dst) .await .map(|_| protocol::Response::Ok) .unwrap_or_else(protocol::Response::from), - protocol::Request::Rename { src, dst } => server - .api + protocol::Request::Rename { src, dst } => api .rename(ctx, src, dst) .await .map(|_| protocol::Response::Ok) @@ -610,20 +644,17 @@ where recursive, only, except, - } => server - .api + } => api .watch(ctx, path, recursive, only, except) .await .map(|_| protocol::Response::Ok) .unwrap_or_else(protocol::Response::from), - protocol::Request::Unwatch { path } => server - .api + protocol::Request::Unwatch { path } => api .unwatch(ctx, path) .await .map(|_| protocol::Response::Ok) .unwrap_or_else(protocol::Response::from), - protocol::Request::Exists { path } => server - .api + protocol::Request::Exists { path } => api .exists(ctx, path) .await .map(|value| protocol::Response::Exists { value }) @@ -632,8 +663,7 @@ where path, canonicalize, resolve_file_type, - } => server - .api + } => api .metadata(ctx, path, canonicalize, resolve_file_type) .await .map(protocol::Response::Metadata) @@ -642,20 +672,17 @@ where path, permissions, options, - } => server - .api + } => api .set_permissions(ctx, path, permissions, options) .await .map(|_| protocol::Response::Ok) .unwrap_or_else(protocol::Response::from), - protocol::Request::Search { query } => server - .api + protocol::Request::Search { query } => api .search(ctx, query) .await .map(|id| protocol::Response::SearchStarted { id }) .unwrap_or_else(protocol::Response::from), - protocol::Request::CancelSearch { id } => server - .api + protocol::Request::CancelSearch { id } => api .cancel_search(ctx, id) .await .map(|_| protocol::Response::Ok) @@ -665,32 +692,27 @@ where environment, current_dir, pty, - } => server - .api + } => api .proc_spawn(ctx, cmd.into(), environment, current_dir, pty) .await .map(|id| protocol::Response::ProcSpawned { id }) .unwrap_or_else(protocol::Response::from), - protocol::Request::ProcKill { id } => server - .api + protocol::Request::ProcKill { id } => api .proc_kill(ctx, id) .await .map(|_| protocol::Response::Ok) .unwrap_or_else(protocol::Response::from), - protocol::Request::ProcStdin { id, data } => server - .api + protocol::Request::ProcStdin { id, data } => api .proc_stdin(ctx, id, data) .await .map(|_| protocol::Response::Ok) .unwrap_or_else(protocol::Response::from), - protocol::Request::ProcResizePty { id, size } => server - .api + protocol::Request::ProcResizePty { id, size } => api .proc_resize_pty(ctx, id, size) .await .map(|_| protocol::Response::Ok) .unwrap_or_else(protocol::Response::from), - protocol::Request::SystemInfo {} => server - .api + protocol::Request::SystemInfo {} => api .system_info(ctx) .await .map(protocol::Response::SystemInfo) diff --git a/distant-core/tests/api_tests.rs b/distant-core/tests/api_tests.rs new file mode 100644 index 0000000..a1dbcda --- /dev/null +++ b/distant-core/tests/api_tests.rs @@ -0,0 +1,347 @@ +use std::io; +use std::path::PathBuf; + +use async_trait::async_trait; +use distant_core::{ + DistantApi, DistantApiServerHandler, DistantChannelExt, DistantClient, DistantCtx, +}; +use distant_net::auth::{DummyAuthHandler, Verifier}; +use distant_net::client::Client; +use distant_net::common::{InmemoryTransport, OneshotListener}; +use distant_net::server::{Server, ServerRef}; + +/// Stands up an inmemory client and server using the given api. +async fn setup( + api: impl DistantApi + Send + Sync + 'static, +) -> (DistantClient, Box) { + let (t1, t2) = InmemoryTransport::pair(100); + + let server = Server::new() + .handler(DistantApiServerHandler::new(api)) + .verifier(Verifier::none()) + .start(OneshotListener::from_value(t2)) + .expect("Failed to start server"); + + let client: DistantClient = Client::build() + .auth_handler(DummyAuthHandler) + .connector(t1) + .connect() + .await + .expect("Failed to connect to server"); + + (client, server) +} + +mod single { + use super::*; + use test_log::test; + + #[test(tokio::test)] + async fn should_support_single_request_returning_error() { + struct TestDistantApi; + + #[async_trait] + impl DistantApi for TestDistantApi { + type LocalData = (); + + async fn read_file( + &self, + _ctx: DistantCtx, + _path: PathBuf, + ) -> io::Result> { + Err(io::Error::new(io::ErrorKind::NotFound, "test error")) + } + } + + let (mut client, _server) = setup(TestDistantApi).await; + + let error = client.read_file(PathBuf::from("file")).await.unwrap_err(); + assert_eq!(error.kind(), io::ErrorKind::NotFound); + assert_eq!(error.to_string(), "test error"); + } + + #[test(tokio::test)] + async fn should_support_single_request_returning_success() { + struct TestDistantApi; + + #[async_trait] + impl DistantApi for TestDistantApi { + type LocalData = (); + + async fn read_file( + &self, + _ctx: DistantCtx, + _path: PathBuf, + ) -> io::Result> { + Ok(b"hello world".to_vec()) + } + } + + let (mut client, _server) = setup(TestDistantApi).await; + + let contents = client.read_file(PathBuf::from("file")).await.unwrap(); + assert_eq!(contents, b"hello world"); + } +} + +mod batch_parallel { + use super::*; + use distant_net::common::Request; + use distant_protocol::{Msg, Request as RequestPayload}; + use std::time::{Duration, SystemTime, UNIX_EPOCH}; + use test_log::test; + + #[test(tokio::test)] + async fn should_support_multiple_requests_running_in_parallel() { + struct TestDistantApi; + + #[async_trait] + impl DistantApi for TestDistantApi { + type LocalData = (); + + async fn read_file( + &self, + _ctx: DistantCtx, + path: PathBuf, + ) -> io::Result> { + if path.to_str().unwrap() == "slow" { + tokio::time::sleep(Duration::from_millis(500)).await; + } + + let time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap(); + Ok((time.as_millis() as u64).to_be_bytes().to_vec()) + } + } + + let (mut client, _server) = setup(TestDistantApi).await; + + let request = Request::new(Msg::batch([ + RequestPayload::FileRead { + path: PathBuf::from("file1"), + }, + RequestPayload::FileRead { + path: PathBuf::from("slow"), + }, + RequestPayload::FileRead { + path: PathBuf::from("file2"), + }, + ])); + + let response = client.send(request).await.unwrap(); + let payloads = response.payload.into_batch().unwrap(); + + // Collect our times from the reading + let mut times = Vec::new(); + for payload in payloads { + match payload { + distant_protocol::Response::Blob { data } => { + let mut buf = [0u8; 8]; + buf.copy_from_slice(&data[..8]); + times.push(u64::from_be_bytes(buf)); + } + x => panic!("Unexpected payload: {x:?}"), + } + } + + // Verify that these ran in parallel as the first and third requests should not be + // over 500 milliseconds apart due to the sleep in the middle! + let diff = times[0].abs_diff(times[2]); + assert!(diff <= 500, "Sequential ordering detected"); + } + + #[test(tokio::test)] + async fn should_run_all_requests_even_if_some_fail() { + struct TestDistantApi; + + #[async_trait] + impl DistantApi for TestDistantApi { + type LocalData = (); + + async fn read_file( + &self, + _ctx: DistantCtx, + path: PathBuf, + ) -> io::Result> { + if path.to_str().unwrap() == "fail" { + return Err(io::Error::new(io::ErrorKind::Other, "test error")); + } + + Ok(Vec::new()) + } + } + + let (mut client, _server) = setup(TestDistantApi).await; + + let request = Request::new(Msg::batch([ + RequestPayload::FileRead { + path: PathBuf::from("file1"), + }, + RequestPayload::FileRead { + path: PathBuf::from("fail"), + }, + RequestPayload::FileRead { + path: PathBuf::from("file2"), + }, + ])); + + let response = client.send(request).await.unwrap(); + let payloads = response.payload.into_batch().unwrap(); + + // Should be a success, error, and success + assert!( + matches!(payloads[0], distant_protocol::Response::Blob { .. }), + "Unexpected payloads[0]: {:?}", + payloads[0] + ); + assert!( + matches!( + &payloads[1], + distant_protocol::Response::Error(distant_protocol::Error { kind, description }) + if matches!(kind, distant_protocol::ErrorKind::Other) && description == "test error" + ), + "Unexpected payloads[1]: {:?}", + payloads[1] + ); + assert!( + matches!(payloads[2], distant_protocol::Response::Blob { .. }), + "Unexpected payloads[2]: {:?}", + payloads[2] + ); + } +} + +mod batch_sequence { + use super::*; + use distant_net::common::Request; + use distant_protocol::{Msg, Request as RequestPayload}; + use std::time::{Duration, SystemTime, UNIX_EPOCH}; + use test_log::test; + + #[test(tokio::test)] + async fn should_support_multiple_requests_running_in_sequence() { + struct TestDistantApi; + + #[async_trait] + impl DistantApi for TestDistantApi { + type LocalData = (); + + async fn read_file( + &self, + _ctx: DistantCtx, + path: PathBuf, + ) -> io::Result> { + if path.to_str().unwrap() == "slow" { + tokio::time::sleep(Duration::from_millis(500)).await; + } + + let time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap(); + Ok((time.as_millis() as u64).to_be_bytes().to_vec()) + } + } + + let (mut client, _server) = setup(TestDistantApi).await; + + let mut request = Request::new(Msg::batch([ + RequestPayload::FileRead { + path: PathBuf::from("file1"), + }, + RequestPayload::FileRead { + path: PathBuf::from("slow"), + }, + RequestPayload::FileRead { + path: PathBuf::from("file2"), + }, + ])); + + // Mark as running in sequence + request.header.insert("sequence", true); + + let response = client.send(request).await.unwrap(); + let payloads = response.payload.into_batch().unwrap(); + + // Collect our times from the reading + let mut times = Vec::new(); + for payload in payloads { + match payload { + distant_protocol::Response::Blob { data } => { + let mut buf = [0u8; 8]; + buf.copy_from_slice(&data[..8]); + times.push(u64::from_be_bytes(buf)); + } + x => panic!("Unexpected payload: {x:?}"), + } + } + + // Verify that these ran in sequence as the first and third requests should be + // over 500 milliseconds apart due to the sleep in the middle! + let diff = times[0].abs_diff(times[2]); + assert!(diff > 500, "Parallel ordering detected"); + } + + #[test(tokio::test)] + async fn should_interrupt_any_requests_following_a_failure() { + struct TestDistantApi; + + #[async_trait] + impl DistantApi for TestDistantApi { + type LocalData = (); + + async fn read_file( + &self, + _ctx: DistantCtx, + path: PathBuf, + ) -> io::Result> { + if path.to_str().unwrap() == "fail" { + return Err(io::Error::new(io::ErrorKind::Other, "test error")); + } + + Ok(Vec::new()) + } + } + + let (mut client, _server) = setup(TestDistantApi).await; + + let mut request = Request::new(Msg::batch([ + RequestPayload::FileRead { + path: PathBuf::from("file1"), + }, + RequestPayload::FileRead { + path: PathBuf::from("fail"), + }, + RequestPayload::FileRead { + path: PathBuf::from("file2"), + }, + ])); + + // Mark as running in sequence + request.header.insert("sequence", true); + + let response = client.send(request).await.unwrap(); + let payloads = response.payload.into_batch().unwrap(); + + // Should be a success, error, and interrupt + assert!( + matches!(payloads[0], distant_protocol::Response::Blob { .. }), + "Unexpected payloads[0]: {:?}", + payloads[0] + ); + assert!( + matches!( + &payloads[1], + distant_protocol::Response::Error(distant_protocol::Error { kind, description }) + if matches!(kind, distant_protocol::ErrorKind::Other) && description == "test error" + ), + "Unexpected payloads[1]: {:?}", + payloads[1] + ); + assert!( + matches!( + &payloads[2], + distant_protocol::Response::Error(distant_protocol::Error { kind, .. }) + if matches!(kind, distant_protocol::ErrorKind::Interrupted) + ), + "Unexpected payloads[2]: {:?}", + payloads[2] + ); + } +} diff --git a/distant-net/src/common.rs b/distant-net/src/common.rs index a0e79dc..5f793c8 100644 --- a/distant-net/src/common.rs +++ b/distant-net/src/common.rs @@ -20,5 +20,4 @@ pub use listener::*; pub use map::*; pub use packet::*; pub use port::*; -pub use serde_json::Value; pub use transport::*; diff --git a/distant-net/src/common/packet.rs b/distant-net/src/common/packet.rs index c55fe2d..14c97be 100644 --- a/distant-net/src/common/packet.rs +++ b/distant-net/src/common/packet.rs @@ -1,10 +1,12 @@ mod header; mod request; mod response; +mod value; pub use header::*; pub use request::*; pub use response::*; +pub use value::*; use std::io::Cursor; diff --git a/distant-net/src/common/packet/header.rs b/distant-net/src/common/packet/header.rs index 93425f4..6712ea6 100644 --- a/distant-net/src/common/packet/header.rs +++ b/distant-net/src/common/packet/header.rs @@ -1,5 +1,6 @@ use crate::common::{utils, Value}; use derive_more::IntoIterator; +use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::io; @@ -54,6 +55,17 @@ impl Header { self.0.insert(key.into(), value.into()) } + /// Retrieves a value from the header, attempting to convert it to the specified type `T` + /// by cloning the value and then converting it. + pub fn get_as(&self, key: impl AsRef) -> Option> + where + T: DeserializeOwned, + { + self.0 + .get(key.as_ref()) + .map(|value| value.clone().cast_as()) + } + /// Serializes the header into bytes. pub fn to_vec(&self) -> io::Result> { utils::serialize_to_vec(self) diff --git a/distant-net/src/common/packet/value.rs b/distant-net/src/common/packet/value.rs new file mode 100644 index 0000000..b490786 --- /dev/null +++ b/distant-net/src/common/packet/value.rs @@ -0,0 +1,110 @@ +use crate::common::utils; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; +use std::borrow::Cow; +use std::io; +use std::ops::{Deref, DerefMut}; + +/// Generic value type for data passed through header. +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(transparent)] +pub struct Value(serde_json::Value); + +impl Value { + /// Creates a new [`Value`] by converting `value` to the underlying type. + pub fn new(value: impl Into) -> Self { + Self(value.into()) + } + + /// Serializes the value into bytes. + pub fn to_vec(&self) -> io::Result> { + utils::serialize_to_vec(self) + } + + /// Deserializes the value from bytes. + pub fn from_slice(slice: &[u8]) -> io::Result { + utils::deserialize_from_slice(slice) + } + + /// Attempts to convert this generic value to a specific type. + pub fn cast_as(self) -> io::Result + where + T: DeserializeOwned, + { + serde_json::from_value(self.0).map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x)) + } +} + +impl Deref for Value { + type Target = serde_json::Value; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for Value { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +macro_rules! impl_from { + ($($type:ty),+) => { + $( + impl From<$type> for Value { + fn from(x: $type) -> Self { + Self(From::from(x)) + } + } + )+ + }; +} + +impl_from!( + (), + i8, i16, i32, i64, isize, + u8, u16, u32, u64, usize, + f32, f64, + bool, String, serde_json::Number, + serde_json::Map +); + +impl<'a, T> From<&'a [T]> for Value +where + T: Clone + Into, +{ + fn from(x: &'a [T]) -> Self { + Self(From::from(x)) + } +} + +impl<'a> From<&'a str> for Value { + fn from(x: &'a str) -> Self { + Self(From::from(x)) + } +} + +impl<'a> From> for Value { + fn from(x: Cow<'a, str>) -> Self { + Self(From::from(x)) + } +} + +impl From> for Value +where + T: Into, +{ + fn from(x: Option) -> Self { + Self(From::from(x)) + } +} + +impl From> for Value +where + T: Into, +{ + fn from(x: Vec) -> Self { + Self(From::from(x)) + } +}