diff --git a/CHANGELOG.md b/CHANGELOG.md index 77bf3ea..f0ebade 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 `distant-local` implementation sending a separate `Changed` event per path - `ChangeDetails` now includes a `renamed` field to capture the new path name when known +- `DistantApi` now handles batch requests in parallel, returning the results in + order. To achieve the previous sequential processing of batch requests, the + header value `sequence` needs to be set to true ## [0.20.0-alpha.8] diff --git a/distant-core/src/api.rs b/distant-core/src/api.rs index 41ef0f5..3a473ef 100644 --- a/distant-core/src/api.rs +++ b/distant-core/src/api.rs @@ -27,7 +27,7 @@ pub struct DistantApiServerHandler where T: DistantApi, { - api: T, + api: Arc, } impl DistantApiServerHandler @@ -35,7 +35,7 @@ where T: DistantApi, { pub fn new(api: T) -> Self { - Self { api } + Self { api: Arc::new(api) } } } @@ -424,8 +424,8 @@ pub trait DistantApi { #[async_trait] impl ServerHandler for DistantApiServerHandler where - T: DistantApi + Send + Sync, - D: Send + Sync, + T: DistantApi + Send + Sync + 'static, + D: Send + Sync + 'static, { type LocalData = D; type Request = protocol::Msg; @@ -457,7 +457,7 @@ where local_data, }; - let data = handle_request(self, ctx, data).await; + let data = handle_request(Arc::clone(&self.api), ctx, data).await; // Report outgoing errors in our debug logs if let protocol::Response::Error(x) = &data { @@ -466,7 +466,9 @@ where protocol::Msg::Single(data) } - protocol::Msg::Batch(list) => { + protocol::Msg::Batch(list) + if matches!(request.header.get_as("sequence"), Some(Ok(true))) => + { let mut out = Vec::new(); for data in list { @@ -476,13 +478,7 @@ where local_data: Arc::clone(&local_data), }; - // TODO: This does not run in parallel, meaning that the next item in the - // batch will not be queued until the previous item completes! This - // would be useful if we wanted to chain requests where the previous - // request feeds into the current request, but not if we just want - // to run everything together. So we should instead rewrite this - // to spawn a task per request and then await completion of all tasks - let data = handle_request(self, ctx, data).await; + let data = handle_request(Arc::clone(&self.api), ctx, data).await; // Report outgoing errors in our debug logs if let protocol::Response::Error(x) = &data { @@ -494,6 +490,44 @@ where protocol::Msg::Batch(out) } + protocol::Msg::Batch(list) => { + let mut tasks = Vec::new(); + + // If sequence specified as true, we want to process in order, otherwise we can + // process in any order + + for data in list { + let api = Arc::clone(&self.api); + let ctx = DistantCtx { + connection_id, + reply: Box::new(DistantSingleReply::from(reply.clone_reply())), + local_data: Arc::clone(&local_data), + }; + + let task = tokio::spawn(async move { + let data = handle_request(api, ctx, data).await; + + // Report outgoing errors in our debug logs + if let protocol::Response::Error(x) = &data { + debug!("[Conn {}] {}", connection_id, x); + } + + data + }); + + tasks.push(task); + } + + let out = futures::future::join_all(tasks) + .await + .into_iter() + .map(|x| match x { + Ok(x) => x, + Err(x) => protocol::Response::Error(x.to_string().into()), + }) + .collect(); + protocol::Msg::Batch(out) + } }; // Queue up our result to go before ANY of the other messages that might be sent. @@ -515,7 +549,7 @@ where /// Processes an incoming request async fn handle_request( - server: &DistantApiServerHandler, + api: Arc, ctx: DistantCtx, request: protocol::Request, ) -> protocol::Response @@ -524,44 +558,37 @@ where D: Send + Sync, { match request { - protocol::Request::Version {} => server - .api + protocol::Request::Version {} => api .version(ctx) .await .map(protocol::Response::Version) .unwrap_or_else(protocol::Response::from), - protocol::Request::FileRead { path } => server - .api + protocol::Request::FileRead { path } => api .read_file(ctx, path) .await .map(|data| protocol::Response::Blob { data }) .unwrap_or_else(protocol::Response::from), - protocol::Request::FileReadText { path } => server - .api + protocol::Request::FileReadText { path } => api .read_file_text(ctx, path) .await .map(|data| protocol::Response::Text { data }) .unwrap_or_else(protocol::Response::from), - protocol::Request::FileWrite { path, data } => server - .api + protocol::Request::FileWrite { path, data } => api .write_file(ctx, path, data) .await .map(|_| protocol::Response::Ok) .unwrap_or_else(protocol::Response::from), - protocol::Request::FileWriteText { path, text } => server - .api + protocol::Request::FileWriteText { path, text } => api .write_file_text(ctx, path, text) .await .map(|_| protocol::Response::Ok) .unwrap_or_else(protocol::Response::from), - protocol::Request::FileAppend { path, data } => server - .api + protocol::Request::FileAppend { path, data } => api .append_file(ctx, path, data) .await .map(|_| protocol::Response::Ok) .unwrap_or_else(protocol::Response::from), - protocol::Request::FileAppendText { path, text } => server - .api + protocol::Request::FileAppendText { path, text } => api .append_file_text(ctx, path, text) .await .map(|_| protocol::Response::Ok) @@ -572,8 +599,7 @@ where absolute, canonicalize, include_root, - } => server - .api + } => api .read_dir(ctx, path, depth, absolute, canonicalize, include_root) .await .map(|(entries, errors)| protocol::Response::DirEntries { @@ -581,26 +607,22 @@ where errors: errors.into_iter().map(Error::from).collect(), }) .unwrap_or_else(protocol::Response::from), - protocol::Request::DirCreate { path, all } => server - .api + protocol::Request::DirCreate { path, all } => api .create_dir(ctx, path, all) .await .map(|_| protocol::Response::Ok) .unwrap_or_else(protocol::Response::from), - protocol::Request::Remove { path, force } => server - .api + protocol::Request::Remove { path, force } => api .remove(ctx, path, force) .await .map(|_| protocol::Response::Ok) .unwrap_or_else(protocol::Response::from), - protocol::Request::Copy { src, dst } => server - .api + protocol::Request::Copy { src, dst } => api .copy(ctx, src, dst) .await .map(|_| protocol::Response::Ok) .unwrap_or_else(protocol::Response::from), - protocol::Request::Rename { src, dst } => server - .api + protocol::Request::Rename { src, dst } => api .rename(ctx, src, dst) .await .map(|_| protocol::Response::Ok) @@ -610,20 +632,17 @@ where recursive, only, except, - } => server - .api + } => api .watch(ctx, path, recursive, only, except) .await .map(|_| protocol::Response::Ok) .unwrap_or_else(protocol::Response::from), - protocol::Request::Unwatch { path } => server - .api + protocol::Request::Unwatch { path } => api .unwatch(ctx, path) .await .map(|_| protocol::Response::Ok) .unwrap_or_else(protocol::Response::from), - protocol::Request::Exists { path } => server - .api + protocol::Request::Exists { path } => api .exists(ctx, path) .await .map(|value| protocol::Response::Exists { value }) @@ -632,8 +651,7 @@ where path, canonicalize, resolve_file_type, - } => server - .api + } => api .metadata(ctx, path, canonicalize, resolve_file_type) .await .map(protocol::Response::Metadata) @@ -642,20 +660,17 @@ where path, permissions, options, - } => server - .api + } => api .set_permissions(ctx, path, permissions, options) .await .map(|_| protocol::Response::Ok) .unwrap_or_else(protocol::Response::from), - protocol::Request::Search { query } => server - .api + protocol::Request::Search { query } => api .search(ctx, query) .await .map(|id| protocol::Response::SearchStarted { id }) .unwrap_or_else(protocol::Response::from), - protocol::Request::CancelSearch { id } => server - .api + protocol::Request::CancelSearch { id } => api .cancel_search(ctx, id) .await .map(|_| protocol::Response::Ok) @@ -665,32 +680,27 @@ where environment, current_dir, pty, - } => server - .api + } => api .proc_spawn(ctx, cmd.into(), environment, current_dir, pty) .await .map(|id| protocol::Response::ProcSpawned { id }) .unwrap_or_else(protocol::Response::from), - protocol::Request::ProcKill { id } => server - .api + protocol::Request::ProcKill { id } => api .proc_kill(ctx, id) .await .map(|_| protocol::Response::Ok) .unwrap_or_else(protocol::Response::from), - protocol::Request::ProcStdin { id, data } => server - .api + protocol::Request::ProcStdin { id, data } => api .proc_stdin(ctx, id, data) .await .map(|_| protocol::Response::Ok) .unwrap_or_else(protocol::Response::from), - protocol::Request::ProcResizePty { id, size } => server - .api + protocol::Request::ProcResizePty { id, size } => api .proc_resize_pty(ctx, id, size) .await .map(|_| protocol::Response::Ok) .unwrap_or_else(protocol::Response::from), - protocol::Request::SystemInfo {} => server - .api + protocol::Request::SystemInfo {} => api .system_info(ctx) .await .map(protocol::Response::SystemInfo) diff --git a/distant-net/src/common.rs b/distant-net/src/common.rs index a0e79dc..5f793c8 100644 --- a/distant-net/src/common.rs +++ b/distant-net/src/common.rs @@ -20,5 +20,4 @@ pub use listener::*; pub use map::*; pub use packet::*; pub use port::*; -pub use serde_json::Value; pub use transport::*; diff --git a/distant-net/src/common/packet.rs b/distant-net/src/common/packet.rs index c55fe2d..14c97be 100644 --- a/distant-net/src/common/packet.rs +++ b/distant-net/src/common/packet.rs @@ -1,10 +1,12 @@ mod header; mod request; mod response; +mod value; pub use header::*; pub use request::*; pub use response::*; +pub use value::*; use std::io::Cursor; diff --git a/distant-net/src/common/packet/header.rs b/distant-net/src/common/packet/header.rs index 93425f4..6712ea6 100644 --- a/distant-net/src/common/packet/header.rs +++ b/distant-net/src/common/packet/header.rs @@ -1,5 +1,6 @@ use crate::common::{utils, Value}; use derive_more::IntoIterator; +use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::io; @@ -54,6 +55,17 @@ impl Header { self.0.insert(key.into(), value.into()) } + /// Retrieves a value from the header, attempting to convert it to the specified type `T` + /// by cloning the value and then converting it. + pub fn get_as(&self, key: impl AsRef) -> Option> + where + T: DeserializeOwned, + { + self.0 + .get(key.as_ref()) + .map(|value| value.clone().cast_as()) + } + /// Serializes the header into bytes. pub fn to_vec(&self) -> io::Result> { utils::serialize_to_vec(self) diff --git a/distant-net/src/common/packet/value.rs b/distant-net/src/common/packet/value.rs new file mode 100644 index 0000000..b490786 --- /dev/null +++ b/distant-net/src/common/packet/value.rs @@ -0,0 +1,110 @@ +use crate::common::utils; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; +use std::borrow::Cow; +use std::io; +use std::ops::{Deref, DerefMut}; + +/// Generic value type for data passed through header. +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(transparent)] +pub struct Value(serde_json::Value); + +impl Value { + /// Creates a new [`Value`] by converting `value` to the underlying type. + pub fn new(value: impl Into) -> Self { + Self(value.into()) + } + + /// Serializes the value into bytes. + pub fn to_vec(&self) -> io::Result> { + utils::serialize_to_vec(self) + } + + /// Deserializes the value from bytes. + pub fn from_slice(slice: &[u8]) -> io::Result { + utils::deserialize_from_slice(slice) + } + + /// Attempts to convert this generic value to a specific type. + pub fn cast_as(self) -> io::Result + where + T: DeserializeOwned, + { + serde_json::from_value(self.0).map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x)) + } +} + +impl Deref for Value { + type Target = serde_json::Value; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for Value { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +macro_rules! impl_from { + ($($type:ty),+) => { + $( + impl From<$type> for Value { + fn from(x: $type) -> Self { + Self(From::from(x)) + } + } + )+ + }; +} + +impl_from!( + (), + i8, i16, i32, i64, isize, + u8, u16, u32, u64, usize, + f32, f64, + bool, String, serde_json::Number, + serde_json::Map +); + +impl<'a, T> From<&'a [T]> for Value +where + T: Clone + Into, +{ + fn from(x: &'a [T]) -> Self { + Self(From::from(x)) + } +} + +impl<'a> From<&'a str> for Value { + fn from(x: &'a str) -> Self { + Self(From::from(x)) + } +} + +impl<'a> From> for Value { + fn from(x: Cow<'a, str>) -> Self { + Self(From::from(x)) + } +} + +impl From> for Value +where + T: Into, +{ + fn from(x: Option) -> Self { + Self(From::from(x)) + } +} + +impl From> for Value +where + T: Into, +{ + fn from(x: Vec) -> Self { + Self(From::from(x)) + } +}