Fix server hangup (#206)

pull/207/head
Chip Senkbeil 11 months ago committed by GitHub
parent 8009cc9361
commit da75801639
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `Request` and `Response` types from `distant-net` now support an optional - `Request` and `Response` types from `distant-net` now support an optional
`Header` to send miscellaneous information `Header` to send miscellaneous information
- New feature `tracing` provides https://github.com/tokio-rs/tracing support
as a new `--tracing` flag. Must be compiled with
`RUSTFLAGS="--cfg tokio_unstable"` to properly operate.
### Changed ### Changed
@ -21,6 +24,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `DistantApi` now handles batch requests in parallel, returning the results in - `DistantApi` now handles batch requests in parallel, returning the results in
order. To achieve the previous sequential processing of batch requests, the order. To achieve the previous sequential processing of batch requests, the
header value `sequence` needs to be set to true header value `sequence` needs to be set to true
- Rename `GenericServerRef` to `ServerRef` and remove `ServerRef` trait,
refactoring `TcpServerRef`, `UnixSocketServerRef`, and `WindowsPipeServerRef`
to use the struct instead of `Box<dyn ServerRef>`
## [0.20.0-alpha.8] ## [0.20.0-alpha.8]

@ -5,11 +5,13 @@ use async_trait::async_trait;
use crate::authenticator::Authenticator; use crate::authenticator::Authenticator;
use crate::methods::AuthenticationMethod; use crate::methods::AuthenticationMethod;
/// Authenticaton method for a static secret key /// Authenticaton method that skips authentication and approves anything.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct NoneAuthenticationMethod; pub struct NoneAuthenticationMethod;
impl NoneAuthenticationMethod { impl NoneAuthenticationMethod {
pub const ID: &str = "none";
#[inline] #[inline]
pub fn new() -> Self { pub fn new() -> Self {
Self Self
@ -26,7 +28,7 @@ impl Default for NoneAuthenticationMethod {
#[async_trait] #[async_trait]
impl AuthenticationMethod for NoneAuthenticationMethod { impl AuthenticationMethod for NoneAuthenticationMethod {
fn id(&self) -> &'static str { fn id(&self) -> &'static str {
"none" Self::ID
} }
async fn authenticate(&self, _: &mut dyn Authenticator) -> io::Result<()> { async fn authenticate(&self, _: &mut dyn Authenticator) -> io::Result<()> {

@ -14,6 +14,8 @@ pub struct StaticKeyAuthenticationMethod<T> {
} }
impl<T> StaticKeyAuthenticationMethod<T> { impl<T> StaticKeyAuthenticationMethod<T> {
pub const ID: &str = "static_key";
#[inline] #[inline]
pub fn new(key: T) -> Self { pub fn new(key: T) -> Self {
Self { key } Self { key }
@ -26,7 +28,7 @@ where
T: FromStr + PartialEq + Send + Sync, T: FromStr + PartialEq + Send + Sync,
{ {
fn id(&self) -> &'static str { fn id(&self) -> &'static str {
"static_key" Self::ID
} }
async fn authenticate(&self, authenticator: &mut dyn Authenticator) -> io::Result<()> { async fn authenticate(&self, authenticator: &mut dyn Authenticator) -> io::Result<()> {

@ -4,7 +4,7 @@ use std::sync::Arc;
use async_trait::async_trait; use async_trait::async_trait;
use distant_net::common::ConnectionId; use distant_net::common::ConnectionId;
use distant_net::server::{ConnectionCtx, Reply, ServerCtx, ServerHandler}; use distant_net::server::{Reply, RequestCtx, ServerHandler};
use log::*; use log::*;
use crate::protocol::{ use crate::protocol::{
@ -16,23 +16,22 @@ mod reply;
use reply::DistantSingleReply; use reply::DistantSingleReply;
/// Represents the context provided to the [`DistantApi`] for incoming requests /// Represents the context provided to the [`DistantApi`] for incoming requests
pub struct DistantCtx<T> { pub struct DistantCtx {
pub connection_id: ConnectionId, pub connection_id: ConnectionId,
pub reply: Box<dyn Reply<Data = protocol::Response>>, pub reply: Box<dyn Reply<Data = protocol::Response>>,
pub local_data: Arc<T>,
} }
/// Represents a [`ServerHandler`] that leverages an API compliant with `distant` /// Represents a [`ServerHandler`] that leverages an API compliant with `distant`
pub struct DistantApiServerHandler<T, D> pub struct DistantApiServerHandler<T>
where where
T: DistantApi<LocalData = D>, T: DistantApi,
{ {
api: Arc<T>, api: Arc<T>,
} }
impl<T, D> DistantApiServerHandler<T, D> impl<T> DistantApiServerHandler<T>
where where
T: DistantApi<LocalData = D>, T: DistantApi,
{ {
pub fn new(api: T) -> Self { pub fn new(api: T) -> Self {
Self { api: Arc::new(api) } Self { api: Arc::new(api) }
@ -51,12 +50,15 @@ fn unsupported<T>(label: &str) -> io::Result<T> {
/// which can be used to build other servers that are compatible with distant /// which can be used to build other servers that are compatible with distant
#[async_trait] #[async_trait]
pub trait DistantApi { pub trait DistantApi {
type LocalData: Send + Sync; /// Invoked whenever a new connection is established.
#[allow(unused_variables)]
async fn on_connect(&self, id: ConnectionId) -> io::Result<()> {
Ok(())
}
/// Invoked whenever a new connection is established, providing a mutable reference to the /// Invoked whenever an existing connection is dropped.
/// newly-created local data. This is a way to support modifying local data before it is used.
#[allow(unused_variables)] #[allow(unused_variables)]
async fn on_accept(&self, ctx: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> { async fn on_disconnect(&self, id: ConnectionId) -> io::Result<()> {
Ok(()) Ok(())
} }
@ -64,7 +66,7 @@ pub trait DistantApi {
/// ///
/// *Override this, otherwise it will return "unsupported" as an error.* /// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)] #[allow(unused_variables)]
async fn version(&self, ctx: DistantCtx<Self::LocalData>) -> io::Result<Version> { async fn version(&self, ctx: DistantCtx) -> io::Result<Version> {
unsupported("version") unsupported("version")
} }
@ -74,11 +76,7 @@ pub trait DistantApi {
/// ///
/// *Override this, otherwise it will return "unsupported" as an error.* /// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)] #[allow(unused_variables)]
async fn read_file( async fn read_file(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<Vec<u8>> {
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
) -> io::Result<Vec<u8>> {
unsupported("read_file") unsupported("read_file")
} }
@ -88,11 +86,7 @@ pub trait DistantApi {
/// ///
/// *Override this, otherwise it will return "unsupported" as an error.* /// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)] #[allow(unused_variables)]
async fn read_file_text( async fn read_file_text(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<String> {
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
) -> io::Result<String> {
unsupported("read_file_text") unsupported("read_file_text")
} }
@ -103,12 +97,7 @@ pub trait DistantApi {
/// ///
/// *Override this, otherwise it will return "unsupported" as an error.* /// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)] #[allow(unused_variables)]
async fn write_file( async fn write_file(&self, ctx: DistantCtx, path: PathBuf, data: Vec<u8>) -> io::Result<()> {
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
data: Vec<u8>,
) -> io::Result<()> {
unsupported("write_file") unsupported("write_file")
} }
@ -121,7 +110,7 @@ pub trait DistantApi {
#[allow(unused_variables)] #[allow(unused_variables)]
async fn write_file_text( async fn write_file_text(
&self, &self,
ctx: DistantCtx<Self::LocalData>, ctx: DistantCtx,
path: PathBuf, path: PathBuf,
data: String, data: String,
) -> io::Result<()> { ) -> io::Result<()> {
@ -135,12 +124,7 @@ pub trait DistantApi {
/// ///
/// *Override this, otherwise it will return "unsupported" as an error.* /// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)] #[allow(unused_variables)]
async fn append_file( async fn append_file(&self, ctx: DistantCtx, path: PathBuf, data: Vec<u8>) -> io::Result<()> {
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
data: Vec<u8>,
) -> io::Result<()> {
unsupported("append_file") unsupported("append_file")
} }
@ -153,7 +137,7 @@ pub trait DistantApi {
#[allow(unused_variables)] #[allow(unused_variables)]
async fn append_file_text( async fn append_file_text(
&self, &self,
ctx: DistantCtx<Self::LocalData>, ctx: DistantCtx,
path: PathBuf, path: PathBuf,
data: String, data: String,
) -> io::Result<()> { ) -> io::Result<()> {
@ -172,7 +156,7 @@ pub trait DistantApi {
#[allow(unused_variables)] #[allow(unused_variables)]
async fn read_dir( async fn read_dir(
&self, &self,
ctx: DistantCtx<Self::LocalData>, ctx: DistantCtx,
path: PathBuf, path: PathBuf,
depth: usize, depth: usize,
absolute: bool, absolute: bool,
@ -189,12 +173,7 @@ pub trait DistantApi {
/// ///
/// *Override this, otherwise it will return "unsupported" as an error.* /// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)] #[allow(unused_variables)]
async fn create_dir( async fn create_dir(&self, ctx: DistantCtx, path: PathBuf, all: bool) -> io::Result<()> {
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
all: bool,
) -> io::Result<()> {
unsupported("create_dir") unsupported("create_dir")
} }
@ -205,12 +184,7 @@ pub trait DistantApi {
/// ///
/// *Override this, otherwise it will return "unsupported" as an error.* /// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)] #[allow(unused_variables)]
async fn copy( async fn copy(&self, ctx: DistantCtx, src: PathBuf, dst: PathBuf) -> io::Result<()> {
&self,
ctx: DistantCtx<Self::LocalData>,
src: PathBuf,
dst: PathBuf,
) -> io::Result<()> {
unsupported("copy") unsupported("copy")
} }
@ -221,12 +195,7 @@ pub trait DistantApi {
/// ///
/// *Override this, otherwise it will return "unsupported" as an error.* /// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)] #[allow(unused_variables)]
async fn remove( async fn remove(&self, ctx: DistantCtx, path: PathBuf, force: bool) -> io::Result<()> {
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
force: bool,
) -> io::Result<()> {
unsupported("remove") unsupported("remove")
} }
@ -237,12 +206,7 @@ pub trait DistantApi {
/// ///
/// *Override this, otherwise it will return "unsupported" as an error.* /// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)] #[allow(unused_variables)]
async fn rename( async fn rename(&self, ctx: DistantCtx, src: PathBuf, dst: PathBuf) -> io::Result<()> {
&self,
ctx: DistantCtx<Self::LocalData>,
src: PathBuf,
dst: PathBuf,
) -> io::Result<()> {
unsupported("rename") unsupported("rename")
} }
@ -257,7 +221,7 @@ pub trait DistantApi {
#[allow(unused_variables)] #[allow(unused_variables)]
async fn watch( async fn watch(
&self, &self,
ctx: DistantCtx<Self::LocalData>, ctx: DistantCtx,
path: PathBuf, path: PathBuf,
recursive: bool, recursive: bool,
only: Vec<ChangeKind>, only: Vec<ChangeKind>,
@ -272,7 +236,7 @@ pub trait DistantApi {
/// ///
/// *Override this, otherwise it will return "unsupported" as an error.* /// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)] #[allow(unused_variables)]
async fn unwatch(&self, ctx: DistantCtx<Self::LocalData>, path: PathBuf) -> io::Result<()> { async fn unwatch(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<()> {
unsupported("unwatch") unsupported("unwatch")
} }
@ -282,7 +246,7 @@ pub trait DistantApi {
/// ///
/// *Override this, otherwise it will return "unsupported" as an error.* /// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)] #[allow(unused_variables)]
async fn exists(&self, ctx: DistantCtx<Self::LocalData>, path: PathBuf) -> io::Result<bool> { async fn exists(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<bool> {
unsupported("exists") unsupported("exists")
} }
@ -296,7 +260,7 @@ pub trait DistantApi {
#[allow(unused_variables)] #[allow(unused_variables)]
async fn metadata( async fn metadata(
&self, &self,
ctx: DistantCtx<Self::LocalData>, ctx: DistantCtx,
path: PathBuf, path: PathBuf,
canonicalize: bool, canonicalize: bool,
resolve_file_type: bool, resolve_file_type: bool,
@ -314,7 +278,7 @@ pub trait DistantApi {
#[allow(unused_variables)] #[allow(unused_variables)]
async fn set_permissions( async fn set_permissions(
&self, &self,
ctx: DistantCtx<Self::LocalData>, ctx: DistantCtx,
path: PathBuf, path: PathBuf,
permissions: Permissions, permissions: Permissions,
options: SetPermissionsOptions, options: SetPermissionsOptions,
@ -328,11 +292,7 @@ pub trait DistantApi {
/// ///
/// *Override this, otherwise it will return "unsupported" as an error.* /// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)] #[allow(unused_variables)]
async fn search( async fn search(&self, ctx: DistantCtx, query: SearchQuery) -> io::Result<SearchId> {
&self,
ctx: DistantCtx<Self::LocalData>,
query: SearchQuery,
) -> io::Result<SearchId> {
unsupported("search") unsupported("search")
} }
@ -342,11 +302,7 @@ pub trait DistantApi {
/// ///
/// *Override this, otherwise it will return "unsupported" as an error.* /// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)] #[allow(unused_variables)]
async fn cancel_search( async fn cancel_search(&self, ctx: DistantCtx, id: SearchId) -> io::Result<()> {
&self,
ctx: DistantCtx<Self::LocalData>,
id: SearchId,
) -> io::Result<()> {
unsupported("cancel_search") unsupported("cancel_search")
} }
@ -361,7 +317,7 @@ pub trait DistantApi {
#[allow(unused_variables)] #[allow(unused_variables)]
async fn proc_spawn( async fn proc_spawn(
&self, &self,
ctx: DistantCtx<Self::LocalData>, ctx: DistantCtx,
cmd: String, cmd: String,
environment: Environment, environment: Environment,
current_dir: Option<PathBuf>, current_dir: Option<PathBuf>,
@ -376,7 +332,7 @@ pub trait DistantApi {
/// ///
/// *Override this, otherwise it will return "unsupported" as an error.* /// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)] #[allow(unused_variables)]
async fn proc_kill(&self, ctx: DistantCtx<Self::LocalData>, id: ProcessId) -> io::Result<()> { async fn proc_kill(&self, ctx: DistantCtx, id: ProcessId) -> io::Result<()> {
unsupported("proc_kill") unsupported("proc_kill")
} }
@ -387,12 +343,7 @@ pub trait DistantApi {
/// ///
/// *Override this, otherwise it will return "unsupported" as an error.* /// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)] #[allow(unused_variables)]
async fn proc_stdin( async fn proc_stdin(&self, ctx: DistantCtx, id: ProcessId, data: Vec<u8>) -> io::Result<()> {
&self,
ctx: DistantCtx<Self::LocalData>,
id: ProcessId,
data: Vec<u8>,
) -> io::Result<()> {
unsupported("proc_stdin") unsupported("proc_stdin")
} }
@ -405,7 +356,7 @@ pub trait DistantApi {
#[allow(unused_variables)] #[allow(unused_variables)]
async fn proc_resize_pty( async fn proc_resize_pty(
&self, &self,
ctx: DistantCtx<Self::LocalData>, ctx: DistantCtx,
id: ProcessId, id: ProcessId,
size: PtySize, size: PtySize,
) -> io::Result<()> { ) -> io::Result<()> {
@ -416,32 +367,34 @@ pub trait DistantApi {
/// ///
/// *Override this, otherwise it will return "unsupported" as an error.* /// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)] #[allow(unused_variables)]
async fn system_info(&self, ctx: DistantCtx<Self::LocalData>) -> io::Result<SystemInfo> { async fn system_info(&self, ctx: DistantCtx) -> io::Result<SystemInfo> {
unsupported("system_info") unsupported("system_info")
} }
} }
#[async_trait] #[async_trait]
impl<T, D> ServerHandler for DistantApiServerHandler<T, D> impl<T> ServerHandler for DistantApiServerHandler<T>
where where
T: DistantApi<LocalData = D> + Send + Sync + 'static, T: DistantApi + Send + Sync + 'static,
D: Send + Sync + 'static,
{ {
type LocalData = D;
type Request = protocol::Msg<protocol::Request>; type Request = protocol::Msg<protocol::Request>;
type Response = protocol::Msg<protocol::Response>; type Response = protocol::Msg<protocol::Response>;
/// Overridden to leverage [`DistantApi`] implementation of `on_accept` /// Overridden to leverage [`DistantApi`] implementation of `on_connect`.
async fn on_accept(&self, ctx: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> { async fn on_connect(&self, id: ConnectionId) -> io::Result<()> {
T::on_accept(&self.api, ctx).await T::on_connect(&self.api, id).await
}
/// Overridden to leverage [`DistantApi`] implementation of `on_disconnect`.
async fn on_disconnect(&self, id: ConnectionId) -> io::Result<()> {
T::on_disconnect(&self.api, id).await
} }
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) { async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>) {
let ServerCtx { let RequestCtx {
connection_id, connection_id,
request, request,
reply, reply,
local_data,
} = ctx; } = ctx;
// Convert our reply to a queued reply so we can ensure that the result // Convert our reply to a queued reply so we can ensure that the result
@ -454,7 +407,6 @@ where
let ctx = DistantCtx { let ctx = DistantCtx {
connection_id, connection_id,
reply: Box::new(DistantSingleReply::from(reply.clone_reply())), reply: Box::new(DistantSingleReply::from(reply.clone_reply())),
local_data,
}; };
let data = handle_request(Arc::clone(&self.api), ctx, data).await; let data = handle_request(Arc::clone(&self.api), ctx, data).await;
@ -485,7 +437,6 @@ where
let ctx = DistantCtx { let ctx = DistantCtx {
connection_id, connection_id,
reply: Box::new(DistantSingleReply::from(reply.clone_reply())), reply: Box::new(DistantSingleReply::from(reply.clone_reply())),
local_data: Arc::clone(&local_data),
}; };
let data = handle_request(Arc::clone(&self.api), ctx, data).await; let data = handle_request(Arc::clone(&self.api), ctx, data).await;
@ -513,7 +464,6 @@ where
let ctx = DistantCtx { let ctx = DistantCtx {
connection_id, connection_id,
reply: Box::new(DistantSingleReply::from(reply.clone_reply())), reply: Box::new(DistantSingleReply::from(reply.clone_reply())),
local_data: Arc::clone(&local_data),
}; };
let task = tokio::spawn(async move { let task = tokio::spawn(async move {
@ -560,14 +510,13 @@ where
} }
/// Processes an incoming request /// Processes an incoming request
async fn handle_request<T, D>( async fn handle_request<T>(
api: Arc<T>, api: Arc<T>,
ctx: DistantCtx<D>, ctx: DistantCtx,
request: protocol::Request, request: protocol::Request,
) -> protocol::Response ) -> protocol::Response
where where
T: DistantApi<LocalData = D> + Send + Sync, T: DistantApi + Send + Sync,
D: Send + Sync,
{ {
match request { match request {
protocol::Request::Version {} => api protocol::Request::Version {} => api

@ -11,9 +11,7 @@ use distant_net::common::{InmemoryTransport, OneshotListener};
use distant_net::server::{Server, ServerRef}; use distant_net::server::{Server, ServerRef};
/// Stands up an inmemory client and server using the given api. /// Stands up an inmemory client and server using the given api.
async fn setup( async fn setup(api: impl DistantApi + Send + Sync + 'static) -> (DistantClient, ServerRef) {
api: impl DistantApi<LocalData = ()> + Send + Sync + 'static,
) -> (DistantClient, Box<dyn ServerRef>) {
let (t1, t2) = InmemoryTransport::pair(100); let (t1, t2) = InmemoryTransport::pair(100);
let server = Server::new() let server = Server::new()
@ -33,22 +31,17 @@ async fn setup(
} }
mod single { mod single {
use super::*;
use test_log::test; use test_log::test;
use super::*;
#[test(tokio::test)] #[test(tokio::test)]
async fn should_support_single_request_returning_error() { async fn should_support_single_request_returning_error() {
struct TestDistantApi; struct TestDistantApi;
#[async_trait] #[async_trait]
impl DistantApi for TestDistantApi { impl DistantApi for TestDistantApi {
type LocalData = (); async fn read_file(&self, _ctx: DistantCtx, _path: PathBuf) -> io::Result<Vec<u8>> {
async fn read_file(
&self,
_ctx: DistantCtx<Self::LocalData>,
_path: PathBuf,
) -> io::Result<Vec<u8>> {
Err(io::Error::new(io::ErrorKind::NotFound, "test error")) Err(io::Error::new(io::ErrorKind::NotFound, "test error"))
} }
} }
@ -66,13 +59,7 @@ mod single {
#[async_trait] #[async_trait]
impl DistantApi for TestDistantApi { impl DistantApi for TestDistantApi {
type LocalData = (); async fn read_file(&self, _ctx: DistantCtx, _path: PathBuf) -> io::Result<Vec<u8>> {
async fn read_file(
&self,
_ctx: DistantCtx<Self::LocalData>,
_path: PathBuf,
) -> io::Result<Vec<u8>> {
Ok(b"hello world".to_vec()) Ok(b"hello world".to_vec())
} }
} }
@ -85,25 +72,21 @@ mod single {
} }
mod batch_parallel { mod batch_parallel {
use super::*; use std::time::{Duration, SystemTime, UNIX_EPOCH};
use distant_net::common::Request; use distant_net::common::Request;
use distant_protocol::{Msg, Request as RequestPayload}; use distant_protocol::{Msg, Request as RequestPayload};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use test_log::test; use test_log::test;
use super::*;
#[test(tokio::test)] #[test(tokio::test)]
async fn should_support_multiple_requests_running_in_parallel() { async fn should_support_multiple_requests_running_in_parallel() {
struct TestDistantApi; struct TestDistantApi;
#[async_trait] #[async_trait]
impl DistantApi for TestDistantApi { impl DistantApi for TestDistantApi {
type LocalData = (); async fn read_file(&self, _ctx: DistantCtx, path: PathBuf) -> io::Result<Vec<u8>> {
async fn read_file(
&self,
_ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
) -> io::Result<Vec<u8>> {
if path.to_str().unwrap() == "slow" { if path.to_str().unwrap() == "slow" {
tokio::time::sleep(Duration::from_millis(500)).await; tokio::time::sleep(Duration::from_millis(500)).await;
} }
@ -155,13 +138,7 @@ mod batch_parallel {
#[async_trait] #[async_trait]
impl DistantApi for TestDistantApi { impl DistantApi for TestDistantApi {
type LocalData = (); async fn read_file(&self, _ctx: DistantCtx, path: PathBuf) -> io::Result<Vec<u8>> {
async fn read_file(
&self,
_ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
) -> io::Result<Vec<u8>> {
if path.to_str().unwrap() == "fail" { if path.to_str().unwrap() == "fail" {
return Err(io::Error::new(io::ErrorKind::Other, "test error")); return Err(io::Error::new(io::ErrorKind::Other, "test error"));
} }
@ -211,25 +188,21 @@ mod batch_parallel {
} }
mod batch_sequence { mod batch_sequence {
use super::*; use std::time::{Duration, SystemTime, UNIX_EPOCH};
use distant_net::common::Request; use distant_net::common::Request;
use distant_protocol::{Msg, Request as RequestPayload}; use distant_protocol::{Msg, Request as RequestPayload};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use test_log::test; use test_log::test;
use super::*;
#[test(tokio::test)] #[test(tokio::test)]
async fn should_support_multiple_requests_running_in_sequence() { async fn should_support_multiple_requests_running_in_sequence() {
struct TestDistantApi; struct TestDistantApi;
#[async_trait] #[async_trait]
impl DistantApi for TestDistantApi { impl DistantApi for TestDistantApi {
type LocalData = (); async fn read_file(&self, _ctx: DistantCtx, path: PathBuf) -> io::Result<Vec<u8>> {
async fn read_file(
&self,
_ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
) -> io::Result<Vec<u8>> {
if path.to_str().unwrap() == "slow" { if path.to_str().unwrap() == "slow" {
tokio::time::sleep(Duration::from_millis(500)).await; tokio::time::sleep(Duration::from_millis(500)).await;
} }
@ -284,13 +257,7 @@ mod batch_sequence {
#[async_trait] #[async_trait]
impl DistantApi for TestDistantApi { impl DistantApi for TestDistantApi {
type LocalData = (); async fn read_file(&self, _ctx: DistantCtx, path: PathBuf) -> io::Result<Vec<u8>> {
async fn read_file(
&self,
_ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
) -> io::Result<Vec<u8>> {
if path.to_str().unwrap() == "fail" { if path.to_str().unwrap() == "fail" {
return Err(io::Error::new(io::ErrorKind::Other, "test error")); return Err(io::Error::new(io::ErrorKind::Other, "test error"));
} }

@ -39,13 +39,7 @@ impl Api {
#[async_trait] #[async_trait]
impl DistantApi for Api { impl DistantApi for Api {
type LocalData = (); async fn read_file(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<Vec<u8>> {
async fn read_file(
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
) -> io::Result<Vec<u8>> {
debug!( debug!(
"[Conn {}] Reading bytes from file {:?}", "[Conn {}] Reading bytes from file {:?}",
ctx.connection_id, path ctx.connection_id, path
@ -54,11 +48,7 @@ impl DistantApi for Api {
tokio::fs::read(path).await tokio::fs::read(path).await
} }
async fn read_file_text( async fn read_file_text(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<String> {
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
) -> io::Result<String> {
debug!( debug!(
"[Conn {}] Reading text from file {:?}", "[Conn {}] Reading text from file {:?}",
ctx.connection_id, path ctx.connection_id, path
@ -67,12 +57,7 @@ impl DistantApi for Api {
tokio::fs::read_to_string(path).await tokio::fs::read_to_string(path).await
} }
async fn write_file( async fn write_file(&self, ctx: DistantCtx, path: PathBuf, data: Vec<u8>) -> io::Result<()> {
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
data: Vec<u8>,
) -> io::Result<()> {
debug!( debug!(
"[Conn {}] Writing bytes to file {:?}", "[Conn {}] Writing bytes to file {:?}",
ctx.connection_id, path ctx.connection_id, path
@ -83,7 +68,7 @@ impl DistantApi for Api {
async fn write_file_text( async fn write_file_text(
&self, &self,
ctx: DistantCtx<Self::LocalData>, ctx: DistantCtx,
path: PathBuf, path: PathBuf,
data: String, data: String,
) -> io::Result<()> { ) -> io::Result<()> {
@ -95,12 +80,7 @@ impl DistantApi for Api {
tokio::fs::write(path, data).await tokio::fs::write(path, data).await
} }
async fn append_file( async fn append_file(&self, ctx: DistantCtx, path: PathBuf, data: Vec<u8>) -> io::Result<()> {
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
data: Vec<u8>,
) -> io::Result<()> {
debug!( debug!(
"[Conn {}] Appending bytes to file {:?}", "[Conn {}] Appending bytes to file {:?}",
ctx.connection_id, path ctx.connection_id, path
@ -116,7 +96,7 @@ impl DistantApi for Api {
async fn append_file_text( async fn append_file_text(
&self, &self,
ctx: DistantCtx<Self::LocalData>, ctx: DistantCtx,
path: PathBuf, path: PathBuf,
data: String, data: String,
) -> io::Result<()> { ) -> io::Result<()> {
@ -135,7 +115,7 @@ impl DistantApi for Api {
async fn read_dir( async fn read_dir(
&self, &self,
ctx: DistantCtx<Self::LocalData>, ctx: DistantCtx,
path: PathBuf, path: PathBuf,
depth: usize, depth: usize,
absolute: bool, absolute: bool,
@ -228,12 +208,7 @@ impl DistantApi for Api {
Ok((entries, errors)) Ok((entries, errors))
} }
async fn create_dir( async fn create_dir(&self, ctx: DistantCtx, path: PathBuf, all: bool) -> io::Result<()> {
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
all: bool,
) -> io::Result<()> {
debug!( debug!(
"[Conn {}] Creating directory {:?} {{all: {}}}", "[Conn {}] Creating directory {:?} {{all: {}}}",
ctx.connection_id, path, all ctx.connection_id, path, all
@ -245,12 +220,7 @@ impl DistantApi for Api {
} }
} }
async fn remove( async fn remove(&self, ctx: DistantCtx, path: PathBuf, force: bool) -> io::Result<()> {
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
force: bool,
) -> io::Result<()> {
debug!( debug!(
"[Conn {}] Removing {:?} {{force: {}}}", "[Conn {}] Removing {:?} {{force: {}}}",
ctx.connection_id, path, force ctx.connection_id, path, force
@ -267,12 +237,7 @@ impl DistantApi for Api {
} }
} }
async fn copy( async fn copy(&self, ctx: DistantCtx, src: PathBuf, dst: PathBuf) -> io::Result<()> {
&self,
ctx: DistantCtx<Self::LocalData>,
src: PathBuf,
dst: PathBuf,
) -> io::Result<()> {
debug!( debug!(
"[Conn {}] Copying {:?} to {:?}", "[Conn {}] Copying {:?} to {:?}",
ctx.connection_id, src, dst ctx.connection_id, src, dst
@ -329,12 +294,7 @@ impl DistantApi for Api {
Ok(()) Ok(())
} }
async fn rename( async fn rename(&self, ctx: DistantCtx, src: PathBuf, dst: PathBuf) -> io::Result<()> {
&self,
ctx: DistantCtx<Self::LocalData>,
src: PathBuf,
dst: PathBuf,
) -> io::Result<()> {
debug!( debug!(
"[Conn {}] Renaming {:?} to {:?}", "[Conn {}] Renaming {:?} to {:?}",
ctx.connection_id, src, dst ctx.connection_id, src, dst
@ -344,7 +304,7 @@ impl DistantApi for Api {
async fn watch( async fn watch(
&self, &self,
ctx: DistantCtx<Self::LocalData>, ctx: DistantCtx,
path: PathBuf, path: PathBuf,
recursive: bool, recursive: bool,
only: Vec<ChangeKind>, only: Vec<ChangeKind>,
@ -372,7 +332,7 @@ impl DistantApi for Api {
Ok(()) Ok(())
} }
async fn unwatch(&self, ctx: DistantCtx<Self::LocalData>, path: PathBuf) -> io::Result<()> { async fn unwatch(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<()> {
debug!("[Conn {}] Unwatching {:?}", ctx.connection_id, path); debug!("[Conn {}] Unwatching {:?}", ctx.connection_id, path);
self.state self.state
@ -382,7 +342,7 @@ impl DistantApi for Api {
Ok(()) Ok(())
} }
async fn exists(&self, ctx: DistantCtx<Self::LocalData>, path: PathBuf) -> io::Result<bool> { async fn exists(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<bool> {
debug!("[Conn {}] Checking if {:?} exists", ctx.connection_id, path); debug!("[Conn {}] Checking if {:?} exists", ctx.connection_id, path);
// Following experimental `std::fs::try_exists`, which checks the error kind of the // Following experimental `std::fs::try_exists`, which checks the error kind of the
@ -396,7 +356,7 @@ impl DistantApi for Api {
async fn metadata( async fn metadata(
&self, &self,
ctx: DistantCtx<Self::LocalData>, ctx: DistantCtx,
path: PathBuf, path: PathBuf,
canonicalize: bool, canonicalize: bool,
resolve_file_type: bool, resolve_file_type: bool,
@ -469,7 +429,7 @@ impl DistantApi for Api {
async fn set_permissions( async fn set_permissions(
&self, &self,
_ctx: DistantCtx<Self::LocalData>, _ctx: DistantCtx,
path: PathBuf, path: PathBuf,
permissions: Permissions, permissions: Permissions,
options: SetPermissionsOptions, options: SetPermissionsOptions,
@ -596,11 +556,7 @@ impl DistantApi for Api {
} }
} }
async fn search( async fn search(&self, ctx: DistantCtx, query: SearchQuery) -> io::Result<SearchId> {
&self,
ctx: DistantCtx<Self::LocalData>,
query: SearchQuery,
) -> io::Result<SearchId> {
debug!( debug!(
"[Conn {}] Performing search via {query:?}", "[Conn {}] Performing search via {query:?}",
ctx.connection_id, ctx.connection_id,
@ -609,11 +565,7 @@ impl DistantApi for Api {
self.state.search.start(query, ctx.reply).await self.state.search.start(query, ctx.reply).await
} }
async fn cancel_search( async fn cancel_search(&self, ctx: DistantCtx, id: SearchId) -> io::Result<()> {
&self,
ctx: DistantCtx<Self::LocalData>,
id: SearchId,
) -> io::Result<()> {
debug!("[Conn {}] Cancelling search {id}", ctx.connection_id,); debug!("[Conn {}] Cancelling search {id}", ctx.connection_id,);
self.state.search.cancel(id).await self.state.search.cancel(id).await
@ -621,7 +573,7 @@ impl DistantApi for Api {
async fn proc_spawn( async fn proc_spawn(
&self, &self,
ctx: DistantCtx<Self::LocalData>, ctx: DistantCtx,
cmd: String, cmd: String,
environment: Environment, environment: Environment,
current_dir: Option<PathBuf>, current_dir: Option<PathBuf>,
@ -637,17 +589,12 @@ impl DistantApi for Api {
.await .await
} }
async fn proc_kill(&self, ctx: DistantCtx<Self::LocalData>, id: ProcessId) -> io::Result<()> { async fn proc_kill(&self, ctx: DistantCtx, id: ProcessId) -> io::Result<()> {
debug!("[Conn {}] Killing process {}", ctx.connection_id, id); debug!("[Conn {}] Killing process {}", ctx.connection_id, id);
self.state.process.kill(id).await self.state.process.kill(id).await
} }
async fn proc_stdin( async fn proc_stdin(&self, ctx: DistantCtx, id: ProcessId, data: Vec<u8>) -> io::Result<()> {
&self,
ctx: DistantCtx<Self::LocalData>,
id: ProcessId,
data: Vec<u8>,
) -> io::Result<()> {
debug!( debug!(
"[Conn {}] Sending stdin to process {}", "[Conn {}] Sending stdin to process {}",
ctx.connection_id, id ctx.connection_id, id
@ -657,7 +604,7 @@ impl DistantApi for Api {
async fn proc_resize_pty( async fn proc_resize_pty(
&self, &self,
ctx: DistantCtx<Self::LocalData>, ctx: DistantCtx,
id: ProcessId, id: ProcessId,
size: PtySize, size: PtySize,
) -> io::Result<()> { ) -> io::Result<()> {
@ -668,7 +615,7 @@ impl DistantApi for Api {
self.state.process.resize_pty(id, size).await self.state.process.resize_pty(id, size).await
} }
async fn system_info(&self, ctx: DistantCtx<Self::LocalData>) -> io::Result<SystemInfo> { async fn system_info(&self, ctx: DistantCtx) -> io::Result<SystemInfo> {
debug!("[Conn {}] Reading system information", ctx.connection_id); debug!("[Conn {}] Reading system information", ctx.connection_id);
Ok(SystemInfo { Ok(SystemInfo {
family: env::consts::FAMILY.to_string(), family: env::consts::FAMILY.to_string(),
@ -685,7 +632,7 @@ impl DistantApi for Api {
}) })
} }
async fn version(&self, ctx: DistantCtx<Self::LocalData>) -> io::Result<Version> { async fn version(&self, ctx: DistantCtx) -> io::Result<Version> {
debug!("[Conn {}] Querying version", ctx.connection_id); debug!("[Conn {}] Querying version", ctx.connection_id);
Ok(Version { Ok(Version {
@ -698,11 +645,10 @@ impl DistantApi for Api {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use assert_fs::prelude::*; use assert_fs::prelude::*;
use distant_core::net::server::{ConnectionCtx, Reply}; use distant_core::net::server::Reply;
use distant_core::protocol::Response; use distant_core::protocol::Response;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use predicates::prelude::*; use predicates::prelude::*;
@ -773,7 +719,7 @@ mod tests {
const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(100); const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(100);
async fn setup(buffer: usize) -> (Api, DistantCtx<()>, mpsc::Receiver<Response>) { async fn setup(buffer: usize) -> (Api, DistantCtx, mpsc::Receiver<Response>) {
let api = Api::initialize(Config { let api = Api::initialize(Config {
watch: WatchConfig { watch: WatchConfig {
debounce_timeout: DEBOUNCE_TIMEOUT, debounce_timeout: DEBOUNCE_TIMEOUT,
@ -784,19 +730,10 @@ mod tests {
let (reply, rx) = make_reply(buffer); let (reply, rx) = make_reply(buffer);
let connection_id = rand::random(); let connection_id = rand::random();
DistantApi::on_accept( DistantApi::on_connect(&api, connection_id).await.unwrap();
&api,
ConnectionCtx {
connection_id,
local_data: &mut (),
},
)
.await
.unwrap();
let ctx = DistantCtx { let ctx = DistantCtx {
connection_id, connection_id,
reply, reply,
local_data: Arc::new(()),
}; };
(api, ctx, rx) (api, ctx, rx)
} }
@ -1683,7 +1620,6 @@ mod tests {
let ctx = DistantCtx { let ctx = DistantCtx {
connection_id: ctx_1.connection_id, connection_id: ctx_1.connection_id,
reply, reply,
local_data: Arc::clone(&ctx_1.local_data),
}; };
(ctx, rx) (ctx, rx)
}; };
@ -2662,7 +2598,6 @@ mod tests {
let ctx = DistantCtx { let ctx = DistantCtx {
connection_id: ctx_1.connection_id, connection_id: ctx_1.connection_id,
reply, reply,
local_data: Arc::clone(&ctx_1.local_data),
}; };
(ctx, rx) (ctx, rx)
}; };
@ -2723,7 +2658,6 @@ mod tests {
let ctx = DistantCtx { let ctx = DistantCtx {
connection_id: ctx_1.connection_id, connection_id: ctx_1.connection_id,
reply, reply,
local_data: Arc::clone(&ctx_1.local_data),
}; };
(ctx, rx) (ctx, rx)
}; };

@ -9,10 +9,10 @@ mod config;
mod constants; mod constants;
pub use api::Api; pub use api::Api;
pub use config::*; pub use config::*;
use distant_core::{DistantApi, DistantApiServerHandler}; use distant_core::DistantApiServerHandler;
/// Implementation of [`DistantApiServerHandler`] using [`Api`]. /// Implementation of [`DistantApiServerHandler`] using [`Api`].
pub type Handler = DistantApiServerHandler<Api, <Api as DistantApi>::LocalData>; pub type Handler = DistantApiServerHandler<Api>;
/// Initializes a new [`Handler`]. /// Initializes a new [`Handler`].
pub fn new_handler(config: Config) -> std::io::Result<Handler> { pub fn new_handler(config: Config) -> std::io::Result<Handler> {

@ -216,9 +216,7 @@ impl UntypedClient {
// If we have flagged that a reconnect is needed, attempt to do so // If we have flagged that a reconnect is needed, attempt to do so
if needs_reconnect { if needs_reconnect {
info!("Client encountered issue, attempting to reconnect"); info!("Client encountered issue, attempting to reconnect");
if log::log_enabled!(log::Level::Debug) { debug!("Using strategy {reconnect_strategy:?}");
debug!("Using strategy {reconnect_strategy:?}");
}
match reconnect_strategy.reconnect(&mut connection).await { match reconnect_strategy.reconnect(&mut connection).await {
Ok(()) => { Ok(()) => {
info!("Client successfully reconnected!"); info!("Client successfully reconnected!");
@ -236,7 +234,7 @@ impl UntypedClient {
macro_rules! silence_needs_reconnect { macro_rules! silence_needs_reconnect {
() => {{ () => {{
debug!( info!(
"Client exceeded {}s without server activity, so attempting to reconnect", "Client exceeded {}s without server activity, so attempting to reconnect",
silence_duration.as_secs_f32(), silence_duration.as_secs_f32(),
); );
@ -260,7 +258,7 @@ impl UntypedClient {
let ready = tokio::select! { let ready = tokio::select! {
// NOTE: This should NEVER return None as we never allow the channel to close. // NOTE: This should NEVER return None as we never allow the channel to close.
cb = shutdown_rx.recv() => { cb = shutdown_rx.recv() => {
debug!("Client got shutdown signal, so exiting event loop"); info!("Client got shutdown signal, so exiting event loop");
let cb = cb.expect("Impossible: shutdown channel closed!"); let cb = cb.expect("Impossible: shutdown channel closed!");
let _ = cb.send(Ok(())); let _ = cb.send(Ok(()));
watcher_tx.send_replace(ConnectionState::Disconnected); watcher_tx.send_replace(ConnectionState::Disconnected);
@ -335,7 +333,7 @@ impl UntypedClient {
} }
Ok(None) => { Ok(None) => {
debug!("Connection closed"); info!("Connection closed");
needs_reconnect = true; needs_reconnect = true;
watcher_tx.send_replace(ConnectionState::Reconnecting); watcher_tx.send_replace(ConnectionState::Reconnecting);
continue; continue;

@ -3,13 +3,13 @@ mod request;
mod response; mod response;
mod value; mod value;
use std::io::Cursor;
pub use header::*; pub use header::*;
pub use request::*; pub use request::*;
pub use response::*; pub use response::*;
pub use value::*; pub use value::*;
use std::io::Cursor;
/// Represents a generic id type /// Represents a generic id type
pub type Id = String; pub type Id = String;
@ -257,9 +257,10 @@ mod tests {
use super::*; use super::*;
mod read_str_bytes { mod read_str_bytes {
use super::*;
use test_log::test; use test_log::test;
use super::*;
#[test] #[test]
fn should_fail_if_input_is_empty() { fn should_fail_if_input_is_empty() {
let input = read_str_bytes(&[]).unwrap_err(); let input = read_str_bytes(&[]).unwrap_err();
@ -282,9 +283,10 @@ mod tests {
} }
mod read_key_eq { mod read_key_eq {
use super::*;
use test_log::test; use test_log::test;
use super::*;
#[test] #[test]
fn should_fail_if_input_is_empty() { fn should_fail_if_input_is_empty() {
let input = read_key_eq(&[], "key").unwrap_err(); let input = read_key_eq(&[], "key").unwrap_err();
@ -338,9 +340,10 @@ mod tests {
} }
mod read_header_bytes { mod read_header_bytes {
use super::*;
use test_log::test; use test_log::test;
use super::*;
#[test] #[test]
fn should_fail_if_input_is_empty() { fn should_fail_if_input_is_empty() {
let input = vec![]; let input = vec![];
@ -527,9 +530,10 @@ mod tests {
} }
mod find_msgpack_byte_len { mod find_msgpack_byte_len {
use super::*;
use test_log::test; use test_log::test;
use super::*;
#[test] #[test]
fn should_return_none_if_input_is_empty() { fn should_return_none_if_input_is_empty() {
let input = vec![]; let input = vec![];

@ -1,10 +1,12 @@
use crate::common::{utils, Value}; use std::collections::HashMap;
use std::ops::{Deref, DerefMut};
use std::{fmt, io};
use derive_more::IntoIterator; use derive_more::IntoIterator;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::io; use crate::common::{utils, Value};
use std::ops::{Deref, DerefMut};
/// Generates a new [`Header`] of key/value pairs based on literals. /// Generates a new [`Header`] of key/value pairs based on literals.
/// ///
@ -90,3 +92,18 @@ impl DerefMut for Header {
&mut self.0 &mut self.0
} }
} }
impl fmt::Display for Header {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{{")?;
for (key, value) in self.0.iter() {
let value = serde_json::to_string(value).unwrap_or_else(|_| String::from("--"));
write!(f, "\"{key}\" = {value}")?;
}
write!(f, "}}")?;
Ok(())
}
}

@ -1,10 +1,12 @@
use crate::common::utils;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::borrow::Cow; use std::borrow::Cow;
use std::io; use std::io;
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use crate::common::utils;
/// Generic value type for data passed through header. /// Generic value type for data passed through header.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(transparent)] #[serde(transparent)]

@ -12,7 +12,7 @@ use crate::manager::{
ConnectionInfo, ConnectionList, ManagerAuthenticationId, ManagerCapabilities, ManagerChannelId, ConnectionInfo, ConnectionList, ManagerAuthenticationId, ManagerCapabilities, ManagerChannelId,
ManagerRequest, ManagerResponse, ManagerRequest, ManagerResponse,
}; };
use crate::server::{Server, ServerCtx, ServerHandler}; use crate::server::{RequestCtx, Server, ServerHandler};
mod authentication; mod authentication;
pub use authentication::*; pub use authentication::*;
@ -31,6 +31,10 @@ pub struct ManagerServer {
/// Configuration settings for the server /// Configuration settings for the server
config: Config, config: Config,
/// Holds on to open channels feeding data back from a server to some connected client,
/// enabling us to cancel the tasks on demand
channels: RwLock<HashMap<ManagerChannelId, ManagerChannel>>,
/// Mapping of connection id -> connection /// Mapping of connection id -> connection
connections: RwLock<HashMap<ConnectionId, ManagerConnection>>, connections: RwLock<HashMap<ConnectionId, ManagerConnection>>,
@ -46,6 +50,7 @@ impl ManagerServer {
pub fn new(config: Config) -> Server<Self> { pub fn new(config: Config) -> Server<Self> {
Server::new().handler(Self { Server::new().handler(Self {
config, config,
channels: RwLock::new(HashMap::new()),
connections: RwLock::new(HashMap::new()), connections: RwLock::new(HashMap::new()),
registry: Arc::new(RwLock::new(HashMap::new())), registry: Arc::new(RwLock::new(HashMap::new())),
}) })
@ -177,104 +182,120 @@ impl ManagerServer {
} }
} }
#[derive(Default)]
pub struct DistantManagerServerConnection {
/// Holds on to open channels feeding data back from a server to some connected client,
/// enabling us to cancel the tasks on demand
channels: RwLock<HashMap<ManagerChannelId, ManagerChannel>>,
}
#[async_trait] #[async_trait]
impl ServerHandler for ManagerServer { impl ServerHandler for ManagerServer {
type LocalData = DistantManagerServerConnection;
type Request = ManagerRequest; type Request = ManagerRequest;
type Response = ManagerResponse; type Response = ManagerResponse;
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) { async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>) {
let ServerCtx { debug!("manager::on_request({ctx:?})");
let RequestCtx {
connection_id, connection_id,
request, request,
reply, reply,
local_data,
} = ctx; } = ctx;
let response = match request.payload { let response = match request.payload {
ManagerRequest::Capabilities {} => match self.capabilities().await { ManagerRequest::Capabilities {} => {
Ok(supported) => ManagerResponse::Capabilities { supported }, debug!("Looking up capabilities");
Err(x) => ManagerResponse::from(x), match self.capabilities().await {
}, Ok(supported) => ManagerResponse::Capabilities { supported },
Err(x) => ManagerResponse::from(x),
}
}
ManagerRequest::Launch { ManagerRequest::Launch {
destination, destination,
options, options,
} => match self } => {
.launch( info!("Launching {destination} with {options}");
*destination, match self
options, .launch(
ManagerAuthenticator { *destination,
reply: reply.clone(), options,
registry: Arc::clone(&self.registry), ManagerAuthenticator {
}, reply: reply.clone(),
) registry: Arc::clone(&self.registry),
.await },
{ )
Ok(destination) => ManagerResponse::Launched { destination }, .await
Err(x) => ManagerResponse::from(x), {
}, Ok(destination) => ManagerResponse::Launched { destination },
Err(x) => ManagerResponse::from(x),
}
}
ManagerRequest::Connect { ManagerRequest::Connect {
destination, destination,
options, options,
} => match self } => {
.connect( info!("Connecting to {destination} with {options}");
*destination, match self
options, .connect(
ManagerAuthenticator { *destination,
reply: reply.clone(), options,
registry: Arc::clone(&self.registry), ManagerAuthenticator {
}, reply: reply.clone(),
) registry: Arc::clone(&self.registry),
.await },
{ )
Ok(id) => ManagerResponse::Connected { id }, .await
Err(x) => ManagerResponse::from(x), {
}, Ok(id) => ManagerResponse::Connected { id },
Err(x) => ManagerResponse::from(x),
}
}
ManagerRequest::Authenticate { id, msg } => { ManagerRequest::Authenticate { id, msg } => {
trace!("Retrieving authentication callback registry");
match self.registry.write().await.remove(&id) { match self.registry.write().await.remove(&id) {
Some(cb) => match cb.send(msg) { Some(cb) => {
Ok(_) => return, trace!("Sending {msg:?} through authentication callback");
Err(_) => ManagerResponse::Error { match cb.send(msg) {
description: "Unable to forward authentication callback".to_string(), Ok(_) => return,
}, Err(_) => ManagerResponse::Error {
}, description: "Unable to forward authentication callback"
.to_string(),
},
}
}
None => ManagerResponse::from(io::Error::new( None => ManagerResponse::from(io::Error::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
"Invalid authentication id", "Invalid authentication id",
)), )),
} }
} }
ManagerRequest::OpenChannel { id } => match self.connections.read().await.get(&id) { ManagerRequest::OpenChannel { id } => {
Some(connection) => match connection.open_channel(reply.clone()) { debug!("Attempting to retrieve connection {id}");
Ok(channel) => { match self.connections.read().await.get(&id) {
debug!("[Conn {id}] Channel {} has been opened", channel.id()); Some(connection) => {
let id = channel.id(); debug!("Opening channel through connection {id}");
local_data.channels.write().await.insert(id, channel); match connection.open_channel(reply.clone()) {
ManagerResponse::ChannelOpened { id } Ok(channel) => {
info!("[Conn {id}] Channel {} has been opened", channel.id());
let id = channel.id();
self.channels.write().await.insert(id, channel);
ManagerResponse::ChannelOpened { id }
}
Err(x) => ManagerResponse::from(x),
}
} }
Err(x) => ManagerResponse::from(x), None => ManagerResponse::from(io::Error::new(
}, io::ErrorKind::NotConnected,
None => ManagerResponse::from(io::Error::new( "Connection does not exist",
io::ErrorKind::NotConnected, )),
"Connection does not exist", }
)), }
},
ManagerRequest::Channel { id, request } => { ManagerRequest::Channel { id, request } => {
match local_data.channels.read().await.get(&id) { debug!("Attempting to retrieve channel {id}");
match self.channels.read().await.get(&id) {
// TODO: For now, we are NOT sending back a response to acknowledge // TODO: For now, we are NOT sending back a response to acknowledge
// a successful channel send. We could do this in order for // a successful channel send. We could do this in order for
// the client to listen for a complete send, but is it worth it? // the client to listen for a complete send, but is it worth it?
Some(channel) => match channel.send(request) { Some(channel) => {
Ok(_) => return, debug!("Sending {request:?} through channel {id}");
Err(x) => ManagerResponse::from(x), match channel.send(request) {
}, Ok(_) => return,
Err(x) => ManagerResponse::from(x),
}
}
None => ManagerResponse::from(io::Error::new( None => ManagerResponse::from(io::Error::new(
io::ErrorKind::NotConnected, io::ErrorKind::NotConnected,
"Channel is not open or does not exist", "Channel is not open or does not exist",
@ -282,32 +303,54 @@ impl ServerHandler for ManagerServer {
} }
} }
ManagerRequest::CloseChannel { id } => { ManagerRequest::CloseChannel { id } => {
match local_data.channels.write().await.remove(&id) { debug!("Attempting to remove channel {id}");
Some(channel) => match channel.close() { match self.channels.write().await.remove(&id) {
Ok(_) => { Some(channel) => {
debug!("Channel {id} has been closed"); debug!("Removed channel {}", channel.id());
ManagerResponse::ChannelClosed { id } match channel.close() {
Ok(_) => {
info!("Channel {id} has been closed");
ManagerResponse::ChannelClosed { id }
}
Err(x) => ManagerResponse::from(x),
} }
Err(x) => ManagerResponse::from(x), }
},
None => ManagerResponse::from(io::Error::new( None => ManagerResponse::from(io::Error::new(
io::ErrorKind::NotConnected, io::ErrorKind::NotConnected,
"Channel is not open or does not exist", "Channel is not open or does not exist",
)), )),
} }
} }
ManagerRequest::Info { id } => match self.info(id).await { ManagerRequest::Info { id } => {
Ok(info) => ManagerResponse::Info(info), debug!("Attempting to retrieve information for connection {id}");
Err(x) => ManagerResponse::from(x), match self.info(id).await {
}, Ok(info) => {
ManagerRequest::List => match self.list().await { info!("Retrieved information for connection {id}");
Ok(list) => ManagerResponse::List(list), ManagerResponse::Info(info)
Err(x) => ManagerResponse::from(x), }
}, Err(x) => ManagerResponse::from(x),
ManagerRequest::Kill { id } => match self.kill(id).await { }
Ok(()) => ManagerResponse::Killed, }
Err(x) => ManagerResponse::from(x), ManagerRequest::List => {
}, debug!("Attempting to retrieve the list of connections");
match self.list().await {
Ok(list) => {
info!("Retrieved list of connections");
ManagerResponse::List(list)
}
Err(x) => ManagerResponse::from(x),
}
}
ManagerRequest::Kill { id } => {
debug!("Attempting to kill connection {id}");
match self.kill(id).await {
Ok(()) => {
info!("Killed connection {id}");
ManagerResponse::Killed
}
Err(x) => ManagerResponse::from(x),
}
}
}; };
if let Err(x) = reply.send(response).await { if let Err(x) = reply.send(response).await {
@ -356,6 +399,7 @@ mod tests {
let server = ManagerServer { let server = ManagerServer {
config, config,
channels: RwLock::new(HashMap::new()),
connections: RwLock::new(HashMap::new()), connections: RwLock::new(HashMap::new()),
registry, registry,
}; };

@ -1,5 +1,5 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::io; use std::{fmt, io};
use log::*; use log::*;
use tokio::sync::mpsc; use tokio::sync::mpsc;
@ -135,6 +135,17 @@ enum Action {
}, },
} }
impl fmt::Debug for Action {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Register { id, .. } => write!(f, "Action::Register {{ id: {id}, .. }}"),
Self::Unregister { id } => write!(f, "Action::Unregister {{ id: {id} }}"),
Self::Read { .. } => write!(f, "Action::Read {{ .. }}"),
Self::Write { id, .. } => write!(f, "Action::Write {{ id: {id}, .. }}"),
}
}
}
/// Internal task to process outgoing [`UntypedRequest`]s. /// Internal task to process outgoing [`UntypedRequest`]s.
async fn request_task( async fn request_task(
id: ConnectionId, id: ConnectionId,
@ -142,10 +153,13 @@ async fn request_task(
mut rx: mpsc::UnboundedReceiver<UntypedRequest<'static>>, mut rx: mpsc::UnboundedReceiver<UntypedRequest<'static>>,
) { ) {
while let Some(req) = rx.recv().await { while let Some(req) = rx.recv().await {
trace!("[Conn {id}] Firing off request {}", req.id);
if let Err(x) = client.fire(req).await { if let Err(x) = client.fire(req).await {
error!("[Conn {id}] Failed to send request: {x}"); error!("[Conn {id}] Failed to send request: {x}");
} }
} }
trace!("[Conn {id}] Manager request task closed");
} }
/// Internal task to process incoming [`UntypedResponse`]s. /// Internal task to process incoming [`UntypedResponse`]s.
@ -155,10 +169,17 @@ async fn response_task(
tx: mpsc::UnboundedSender<Action>, tx: mpsc::UnboundedSender<Action>,
) { ) {
while let Some(res) = mailbox.next().await { while let Some(res) = mailbox.next().await {
trace!(
"[Conn {id}] Receiving response {} to request {}",
res.id,
res.origin_id
);
if let Err(x) = tx.send(Action::Read { res }) { if let Err(x) = tx.send(Action::Read { res }) {
error!("[Conn {id}] Failed to forward received response: {x}"); error!("[Conn {id}] Failed to forward received response: {x}");
} }
} }
trace!("[Conn {id}] Manager response task closed");
} }
/// Internal task to process [`Action`] items. /// Internal task to process [`Action`] items.
@ -174,6 +195,8 @@ async fn action_task(
let mut registered = HashMap::new(); let mut registered = HashMap::new();
while let Some(action) = rx.recv().await { while let Some(action) = rx.recv().await {
trace!("[Conn {id}] {action:?}");
match action { match action {
Action::Register { id, reply } => { Action::Register { id, reply } => {
registered.insert(id, reply); registered.insert(id, reply);
@ -201,9 +224,20 @@ async fn action_task(
id: channel_id, id: channel_id,
response: res, response: res,
}; };
if let Err(x) = reply.send(response).await {
error!("[Conn {id}] {x}"); // TODO: This seems to get stuck at times with some change recently,
} // so we kick this off in a new task instead. The better solution
// is to switch most of our mpsc usage to be unbounded so we
// don't need an async call. The only bounded ones should be those
// externally facing to the API user, if even that.
//
// https://github.com/chipsenkbeil/distant/issues/205
let reply = reply.clone();
tokio::spawn(async move {
if let Err(x) = reply.send(response).await {
error!("[Conn {id}] {x}");
}
});
} }
} }
Action::Write { id, mut req } => { Action::Write { id, mut req } => {
@ -217,4 +251,6 @@ async fn action_task(
} }
} }
} }
trace!("[Conn {id}] Manager action task closed");
} }

@ -9,7 +9,7 @@ use serde::de::DeserializeOwned;
use serde::Serialize; use serde::Serialize;
use tokio::sync::{broadcast, RwLock}; use tokio::sync::{broadcast, RwLock};
use crate::common::{Listener, Response, Transport}; use crate::common::{ConnectionId, Listener, Response, Transport};
mod builder; mod builder;
pub use builder::*; pub use builder::*;
@ -56,23 +56,21 @@ pub trait ServerHandler: Send {
/// Type of data sent back by the server /// Type of data sent back by the server
type Response; type Response;
/// Type of data to store locally tied to the specific connection
type LocalData: Send;
/// Invoked upon a new connection becoming established. /// Invoked upon a new connection becoming established.
///
/// ### Note
///
/// This can be useful in performing some additional initialization on the connection's local
/// data prior to it being used anywhere else.
#[allow(unused_variables)] #[allow(unused_variables)]
async fn on_accept(&self, ctx: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> { async fn on_connect(&self, id: ConnectionId) -> io::Result<()> {
Ok(())
}
/// Invoked upon an existing connection getting dropped.
#[allow(unused_variables)]
async fn on_disconnect(&self, id: ConnectionId) -> io::Result<()> {
Ok(()) Ok(())
} }
/// Invoked upon receiving a request from a client. The server should process this /// Invoked upon receiving a request from a client. The server should process this
/// request, which can be found in `ctx`, and send one or more replies in response. /// request, which can be found in `ctx`, and send one or more replies in response.
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>); async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>);
} }
impl Server<()> { impl Server<()> {
@ -144,11 +142,10 @@ where
T: ServerHandler + Sync + 'static, T: ServerHandler + Sync + 'static,
T::Request: DeserializeOwned + Send + Sync + 'static, T::Request: DeserializeOwned + Send + Sync + 'static,
T::Response: Serialize + Send + 'static, T::Response: Serialize + Send + 'static,
T::LocalData: Default + Send + Sync + 'static,
{ {
/// Consumes the server, starting a task to process connections from the `listener` and /// Consumes the server, starting a task to process connections from the `listener` and
/// returning a [`ServerRef`] that can be used to control the active server instance. /// returning a [`ServerRef`] that can be used to control the active server instance.
pub fn start<L>(self, listener: L) -> io::Result<Box<dyn ServerRef>> pub fn start<L>(self, listener: L) -> io::Result<ServerRef>
where where
L: Listener + 'static, L: Listener + 'static,
L::Output: Transport + 'static, L::Output: Transport + 'static,
@ -157,7 +154,7 @@ where
let (tx, rx) = broadcast::channel(1); let (tx, rx) = broadcast::channel(1);
let task = tokio::spawn(self.task(Arc::clone(&state), listener, tx.clone(), rx)); let task = tokio::spawn(self.task(Arc::clone(&state), listener, tx.clone(), rx));
Ok(Box::new(GenericServerRef { shutdown: tx, task })) Ok(ServerRef { shutdown: tx, task })
} }
/// Internal task that is run to receive connections and spawn connection tasks /// Internal task that is run to receive connections and spawn connection tasks
@ -226,6 +223,9 @@ where
.verifier(Arc::downgrade(&verifier)) .verifier(Arc::downgrade(&verifier))
.spawn(), .spawn(),
); );
// Clean up current tasks being tracked
connection_tasks.retain(|task| !task.is_finished());
} }
// Once we stop listening, we still want to wait until all connections have terminated // Once we stop listening, we still want to wait until all connections have terminated
@ -257,15 +257,10 @@ mod tests {
#[async_trait] #[async_trait]
impl ServerHandler for TestServerHandler { impl ServerHandler for TestServerHandler {
type LocalData = ();
type Request = u16; type Request = u16;
type Response = String; type Response = String;
async fn on_accept(&self, _: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> { async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>) {
Ok(())
}
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) {
// Always send back "hello" // Always send back "hello"
ctx.reply.send("hello".to_string()).await.unwrap(); ctx.reply.send("hello".to_string()).await.unwrap();
} }

@ -42,7 +42,6 @@ where
T: ServerHandler + Sync + 'static, T: ServerHandler + Sync + 'static,
T::Request: DeserializeOwned + Send + Sync + 'static, T::Request: DeserializeOwned + Send + Sync + 'static,
T::Response: Serialize + Send + 'static, T::Response: Serialize + Send + 'static,
T::LocalData: Default + Send + Sync + 'static,
{ {
pub async fn start<P>(self, addr: IpAddr, port: P) -> io::Result<TcpServerRef> pub async fn start<P>(self, addr: IpAddr, port: P) -> io::Result<TcpServerRef>
where where
@ -66,17 +65,16 @@ mod tests {
use super::*; use super::*;
use crate::client::Client; use crate::client::Client;
use crate::common::Request; use crate::common::Request;
use crate::server::ServerCtx; use crate::server::RequestCtx;
pub struct TestServerHandler; pub struct TestServerHandler;
#[async_trait] #[async_trait]
impl ServerHandler for TestServerHandler { impl ServerHandler for TestServerHandler {
type LocalData = ();
type Request = String; type Request = String;
type Response = String; type Response = String;
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) { async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>) {
// Echo back what we received // Echo back what we received
ctx.reply ctx.reply
.send(ctx.request.payload.to_string()) .send(ctx.request.payload.to_string())

@ -42,7 +42,6 @@ where
T: ServerHandler + Sync + 'static, T: ServerHandler + Sync + 'static,
T::Request: DeserializeOwned + Send + Sync + 'static, T::Request: DeserializeOwned + Send + Sync + 'static,
T::Response: Serialize + Send + 'static, T::Response: Serialize + Send + 'static,
T::LocalData: Default + Send + Sync + 'static,
{ {
pub async fn start<P>(self, path: P) -> io::Result<UnixSocketServerRef> pub async fn start<P>(self, path: P) -> io::Result<UnixSocketServerRef>
where where
@ -66,17 +65,16 @@ mod tests {
use super::*; use super::*;
use crate::client::Client; use crate::client::Client;
use crate::common::Request; use crate::common::Request;
use crate::server::ServerCtx; use crate::server::RequestCtx;
pub struct TestServerHandler; pub struct TestServerHandler;
#[async_trait] #[async_trait]
impl ServerHandler for TestServerHandler { impl ServerHandler for TestServerHandler {
type LocalData = ();
type Request = String; type Request = String;
type Response = String; type Response = String;
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) { async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>) {
// Echo back what we received // Echo back what we received
ctx.reply ctx.reply
.send(ctx.request.payload.to_string()) .send(ctx.request.payload.to_string())

@ -42,7 +42,6 @@ where
T: ServerHandler + Sync + 'static, T: ServerHandler + Sync + 'static,
T::Request: DeserializeOwned + Send + Sync + 'static, T::Request: DeserializeOwned + Send + Sync + 'static,
T::Response: Serialize + Send + 'static, T::Response: Serialize + Send + 'static,
T::LocalData: Default + Send + Sync + 'static,
{ {
/// Start a new server at the specified address using the given codec /// Start a new server at the specified address using the given codec
pub async fn start<A>(self, addr: A) -> io::Result<WindowsPipeServerRef> pub async fn start<A>(self, addr: A) -> io::Result<WindowsPipeServerRef>
@ -77,17 +76,16 @@ mod tests {
use super::*; use super::*;
use crate::client::Client; use crate::client::Client;
use crate::common::Request; use crate::common::Request;
use crate::server::ServerCtx; use crate::server::RequestCtx;
pub struct TestServerHandler; pub struct TestServerHandler;
#[async_trait] #[async_trait]
impl ServerHandler for TestServerHandler { impl ServerHandler for TestServerHandler {
type LocalData = ();
type Request = String; type Request = String;
type Response = String; type Response = String;
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) { async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>) {
// Echo back what we received // Echo back what we received
ctx.reply ctx.reply
.send(ctx.request.payload.to_string()) .send(ctx.request.payload.to_string())

@ -12,10 +12,7 @@ use serde::Serialize;
use tokio::sync::{broadcast, mpsc, oneshot, RwLock}; use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use super::{ use super::{ConnectionState, RequestCtx, ServerHandler, ServerReply, ServerState, ShutdownTimer};
ConnectionCtx, ConnectionState, ServerCtx, ServerHandler, ServerReply, ServerState,
ShutdownTimer,
};
use crate::common::{ use crate::common::{
Backup, Connection, Frame, Interest, Keychain, Response, Transport, UntypedRequest, Backup, Connection, Frame, Interest, Keychain, Response, Transport, UntypedRequest,
}; };
@ -226,7 +223,6 @@ where
H: ServerHandler + Sync + 'static, H: ServerHandler + Sync + 'static,
H::Request: DeserializeOwned + Send + Sync + 'static, H::Request: DeserializeOwned + Send + Sync + 'static,
H::Response: Serialize + Send + 'static, H::Response: Serialize + Send + 'static,
H::LocalData: Default + Send + Sync + 'static,
T: Transport + 'static, T: Transport + 'static,
{ {
pub fn spawn(self) -> ConnectionTask { pub fn spawn(self) -> ConnectionTask {
@ -429,16 +425,11 @@ where
let id = connection.id(); let id = connection.id();
// Create local data for the connection and then process it // Create local data for the connection and then process it
debug!("[Conn {id}] Officially accepting connection"); info!("[Conn {id}] Connection established");
let mut local_data = H::LocalData::default(); if let Err(x) = await_or_shutdown!(handler.on_connect(id)) {
if let Err(x) = await_or_shutdown!(handler.on_accept(ConnectionCtx {
connection_id: id,
local_data: &mut local_data
})) {
terminate_connection!(@fatal "[Conn {id}] Accepting connection failed: {x}"); terminate_connection!(@fatal "[Conn {id}] Accepting connection failed: {x}");
} }
let local_data = Arc::new(local_data);
let mut last_heartbeat = Instant::now(); let mut last_heartbeat = Instant::now();
// Restore our connection's channels if we have them, otherwise make new ones // Restore our connection's channels if we have them, otherwise make new ones
@ -483,15 +474,22 @@ where
Ok(Some(frame)) => match UntypedRequest::from_slice(frame.as_item()) { Ok(Some(frame)) => match UntypedRequest::from_slice(frame.as_item()) {
Ok(request) => match request.to_typed_request() { Ok(request) => match request.to_typed_request() {
Ok(request) => { Ok(request) => {
if log::log_enabled!(Level::Debug) {
let debug_header = if !request.header.is_empty() {
format!(" | header {}", request.header)
} else {
String::new()
};
debug!("[Conn {id}] New request {}{debug_header}", request.id);
}
let origin_id = request.id.clone(); let origin_id = request.id.clone();
let ctx = ServerCtx { let ctx = RequestCtx {
connection_id: id, connection_id: id,
request, request,
reply: ServerReply { reply: ServerReply {
origin_id, origin_id,
tx: tx.clone(), tx: tx.clone(),
}, },
local_data: Arc::clone(&local_data),
}; };
// Spawn a new task to run the request handler so we don't block // Spawn a new task to run the request handler so we don't block
@ -500,8 +498,8 @@ where
tokio::spawn(async move { handler.on_request(ctx).await }); tokio::spawn(async move { handler.on_request(ctx).await });
} }
Err(x) => { Err(x) => {
if log::log_enabled!(Level::Trace) { if log::log_enabled!(Level::Debug) {
trace!( error!(
"[Conn {id}] Failed receiving {}", "[Conn {id}] Failed receiving {}",
String::from_utf8_lossy(&request.payload), String::from_utf8_lossy(&request.payload),
); );
@ -600,21 +598,16 @@ mod tests {
use crate::common::{ use crate::common::{
HeapSecretKey, InmemoryTransport, Ready, Reconnectable, Request, Response, HeapSecretKey, InmemoryTransport, Ready, Reconnectable, Request, Response,
}; };
use crate::server::Shutdown; use crate::server::{ConnectionId, Shutdown};
struct TestServerHandler; struct TestServerHandler;
#[async_trait] #[async_trait]
impl ServerHandler for TestServerHandler { impl ServerHandler for TestServerHandler {
type LocalData = ();
type Request = u16; type Request = u16;
type Response = String; type Response = String;
async fn on_accept(&self, _: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> { async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>) {
Ok(())
}
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) {
// Always send back "hello" // Always send back "hello"
ctx.reply.send("hello".to_string()).await.unwrap(); ctx.reply.send("hello".to_string()).await.unwrap();
} }
@ -735,18 +728,14 @@ mod tests {
#[async_trait] #[async_trait]
impl ServerHandler for BadAcceptServerHandler { impl ServerHandler for BadAcceptServerHandler {
type LocalData = ();
type Request = u16; type Request = u16;
type Response = String; type Response = String;
async fn on_accept(&self, _: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> { async fn on_connect(&self, _: ConnectionId) -> io::Result<()> {
Err(io::Error::new(io::ErrorKind::Other, "bad accept")) Err(io::Error::new(io::ErrorKind::Other, "bad connect"))
} }
async fn on_request( async fn on_request(&self, _: RequestCtx<Self::Request, Self::Response>) {
&self,
_: ServerCtx<Self::Request, Self::Response, Self::LocalData>,
) {
unreachable!(); unreachable!();
} }
} }
@ -1027,20 +1016,16 @@ mod tests {
#[async_trait] #[async_trait]
impl ServerHandler for HangingAcceptServerHandler { impl ServerHandler for HangingAcceptServerHandler {
type LocalData = ();
type Request = (); type Request = ();
type Response = (); type Response = ();
async fn on_accept(&self, _: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> { async fn on_connect(&self, _: ConnectionId) -> io::Result<()> {
// Wait "forever" so we can ensure that we fail at this step // Wait "forever" so we can ensure that we fail at this step
tokio::time::sleep(Duration::MAX).await; tokio::time::sleep(Duration::MAX).await;
Err(io::Error::new(io::ErrorKind::Other, "bad accept")) Err(io::Error::new(io::ErrorKind::Other, "bad connect"))
} }
async fn on_request( async fn on_request(&self, _: RequestCtx<Self::Request, Self::Response>) {
&self,
_: ServerCtx<Self::Request, Self::Response, Self::LocalData>,
) {
unreachable!(); unreachable!();
} }
} }
@ -1083,19 +1068,15 @@ mod tests {
#[async_trait] #[async_trait]
impl ServerHandler for AcceptServerHandler { impl ServerHandler for AcceptServerHandler {
type LocalData = ();
type Request = (); type Request = ();
type Response = (); type Response = ();
async fn on_accept(&self, _: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> { async fn on_connect(&self, _: ConnectionId) -> io::Result<()> {
self.tx.send(()).await.unwrap(); self.tx.send(()).await.unwrap();
Ok(()) Ok(())
} }
async fn on_request( async fn on_request(&self, _: RequestCtx<Self::Request, Self::Response>) {
&self,
_: ServerCtx<Self::Request, Self::Response, Self::LocalData>,
) {
unreachable!(); unreachable!();
} }
} }

@ -1,28 +1,29 @@
use std::sync::Arc; use std::fmt;
use super::ServerReply; use super::ServerReply;
use crate::common::{ConnectionId, Request}; use crate::common::{ConnectionId, Request};
/// Represents contextual information for working with an inbound request /// Represents contextual information for working with an inbound request.
pub struct ServerCtx<T, U, D> { pub struct RequestCtx<T, U> {
/// Unique identifer associated with the connection that sent the request /// Unique identifer associated with the connection that sent the request.
pub connection_id: ConnectionId, pub connection_id: ConnectionId,
/// The request being handled /// The request being handled.
pub request: Request<T>, pub request: Request<T>,
/// Used to send replies back to be sent out by the server /// Used to send replies back to be sent out by the server.
pub reply: ServerReply<U>, pub reply: ServerReply<U>,
/// Reference to the connection's local data
pub local_data: Arc<D>,
} }
/// Represents contextual information for working with an inbound connection impl<T, U> fmt::Debug for RequestCtx<T, U>
pub struct ConnectionCtx<'a, D> { where
/// Unique identifer associated with the connection T: fmt::Debug,
pub connection_id: ConnectionId, {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
/// Reference to the connection's local data f.debug_struct("RequestCtx")
pub local_data: &'a mut D, .field("connection_id", &self.connection_id)
.field("request", &self.request)
.field("reply", &"...")
.finish()
}
} }

@ -1,94 +1,27 @@
use std::future::Future; use std::future::Future;
use std::io;
use std::pin::Pin; use std::pin::Pin;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use std::time::Duration;
use tokio::sync::broadcast; use tokio::sync::broadcast;
use tokio::task::{JoinError, JoinHandle}; use tokio::task::{JoinError, JoinHandle};
use crate::common::AsAny; /// Represents a reference to a server
pub struct ServerRef {
/// Interface to engage with a server instance.
pub trait ServerRef: AsAny + Send {
/// Returns true if the server is no longer running.
fn is_finished(&self) -> bool;
/// Sends a shutdown signal to the server.
fn shutdown(&self);
fn wait(self) -> Pin<Box<dyn Future<Output = io::Result<()>>>>
where
Self: Sized + 'static,
{
Box::pin(async {
let task = tokio::spawn(async move {
while !self.is_finished() {
tokio::time::sleep(Duration::from_millis(100)).await;
}
});
task.await
.map_err(|x| io::Error::new(io::ErrorKind::Other, x))
})
}
}
impl dyn ServerRef {
/// Attempts to convert this ref into a concrete ref by downcasting
pub fn as_server_ref<R: ServerRef>(&self) -> Option<&R> {
self.as_any().downcast_ref::<R>()
}
/// Attempts to convert this mutable ref into a concrete mutable ref by downcasting
pub fn as_mut_server_ref<R: ServerRef>(&mut self) -> Option<&mut R> {
self.as_mut_any().downcast_mut::<R>()
}
/// Attempts to convert this into a concrete, boxed ref by downcasting
pub fn into_boxed_server_ref<R: ServerRef>(
self: Box<Self>,
) -> Result<Box<R>, Box<dyn std::any::Any>> {
self.into_any().downcast::<R>()
}
/// Waits for the server to complete by continuously polling the finished state.
pub async fn polling_wait(&self) -> io::Result<()> {
while !self.is_finished() {
tokio::time::sleep(Duration::from_millis(100)).await;
}
Ok(())
}
}
/// Represents a generic reference to a server
pub struct GenericServerRef {
pub(crate) shutdown: broadcast::Sender<()>, pub(crate) shutdown: broadcast::Sender<()>,
pub(crate) task: JoinHandle<()>, pub(crate) task: JoinHandle<()>,
} }
/// Runtime-specific implementation of [`ServerRef`] for a [`tokio::task::JoinHandle`] impl ServerRef {
impl ServerRef for GenericServerRef { pub fn is_finished(&self) -> bool {
fn is_finished(&self) -> bool {
self.task.is_finished() self.task.is_finished()
} }
fn shutdown(&self) { pub fn shutdown(&self) {
let _ = self.shutdown.send(()); let _ = self.shutdown.send(());
} }
fn wait(self) -> Pin<Box<dyn Future<Output = io::Result<()>>>>
where
Self: Sized + 'static,
{
Box::pin(async {
self.task
.await
.map_err(|x| io::Error::new(io::ErrorKind::Other, x))
})
}
} }
impl Future for GenericServerRef { impl Future for ServerRef {
type Output = Result<(), JoinError>; type Output = Result<(), JoinError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {

@ -1,36 +1,59 @@
use std::future::Future;
use std::net::IpAddr; use std::net::IpAddr;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::task::JoinError;
use super::ServerRef; use super::ServerRef;
/// Reference to a TCP server instance /// Reference to a TCP server instance.
pub struct TcpServerRef { pub struct TcpServerRef {
pub(crate) addr: IpAddr, pub(crate) addr: IpAddr,
pub(crate) port: u16, pub(crate) port: u16,
pub(crate) inner: Box<dyn ServerRef>, pub(crate) inner: ServerRef,
} }
impl TcpServerRef { impl TcpServerRef {
pub fn new(addr: IpAddr, port: u16, inner: Box<dyn ServerRef>) -> Self { pub fn new(addr: IpAddr, port: u16, inner: ServerRef) -> Self {
Self { addr, port, inner } Self { addr, port, inner }
} }
/// Returns the IP address that the listener is bound to /// Returns the IP address that the listener is bound to.
pub fn ip_addr(&self) -> IpAddr { pub fn ip_addr(&self) -> IpAddr {
self.addr self.addr
} }
/// Returns the port that the listener is bound to /// Returns the port that the listener is bound to.
pub fn port(&self) -> u16 { pub fn port(&self) -> u16 {
self.port self.port
} }
/// Consumes ref, returning inner ref.
pub fn into_inner(self) -> ServerRef {
self.inner
}
} }
impl ServerRef for TcpServerRef { impl Future for TcpServerRef {
fn is_finished(&self) -> bool { type Output = Result<(), JoinError>;
self.inner.is_finished()
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.inner.task).poll(cx)
} }
}
impl Deref for TcpServerRef {
type Target = ServerRef;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
fn shutdown(&self) { impl DerefMut for TcpServerRef {
self.inner.shutdown(); fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
} }
} }

@ -1,35 +1,53 @@
use std::future::Future;
use std::ops::{Deref, DerefMut};
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::task::JoinError;
use super::ServerRef; use super::ServerRef;
/// Reference to a unix socket server instance /// Reference to a unix socket server instance.
pub struct UnixSocketServerRef { pub struct UnixSocketServerRef {
pub(crate) path: PathBuf, pub(crate) path: PathBuf,
pub(crate) inner: Box<dyn ServerRef>, pub(crate) inner: ServerRef,
} }
impl UnixSocketServerRef { impl UnixSocketServerRef {
pub fn new(path: PathBuf, inner: Box<dyn ServerRef>) -> Self { pub fn new(path: PathBuf, inner: ServerRef) -> Self {
Self { path, inner } Self { path, inner }
} }
/// Returns the path to the socket /// Returns the path to the socket.
pub fn path(&self) -> &Path { pub fn path(&self) -> &Path {
&self.path &self.path
} }
/// Consumes ref, returning inner ref /// Consumes ref, returning inner ref.
pub fn into_inner(self) -> Box<dyn ServerRef> { pub fn into_inner(self) -> ServerRef {
self.inner self.inner
} }
} }
impl ServerRef for UnixSocketServerRef { impl Future for UnixSocketServerRef {
fn is_finished(&self) -> bool { type Output = Result<(), JoinError>;
self.inner.is_finished()
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.inner.task).poll(cx)
} }
}
impl Deref for UnixSocketServerRef {
type Target = ServerRef;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
fn shutdown(&self) { impl DerefMut for UnixSocketServerRef {
self.inner.shutdown(); fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
} }
} }

@ -1,35 +1,53 @@
use std::ffi::{OsStr, OsString}; use std::ffi::{OsStr, OsString};
use std::future::Future;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::task::JoinError;
use super::ServerRef; use super::ServerRef;
/// Reference to a unix socket server instance /// Reference to a windows pipe server instance.
pub struct WindowsPipeServerRef { pub struct WindowsPipeServerRef {
pub(crate) addr: OsString, pub(crate) addr: OsString,
pub(crate) inner: Box<dyn ServerRef>, pub(crate) inner: ServerRef,
} }
impl WindowsPipeServerRef { impl WindowsPipeServerRef {
pub fn new(addr: OsString, inner: Box<dyn ServerRef>) -> Self { pub fn new(addr: OsString, inner: ServerRef) -> Self {
Self { addr, inner } Self { addr, inner }
} }
/// Returns the addr that the listener is bound to /// Returns the addr that the listener is bound to.
pub fn addr(&self) -> &OsStr { pub fn addr(&self) -> &OsStr {
&self.addr &self.addr
} }
/// Consumes ref, returning inner ref /// Consumes ref, returning inner ref.
pub fn into_inner(self) -> Box<dyn ServerRef> { pub fn into_inner(self) -> ServerRef {
self.inner self.inner
} }
} }
impl ServerRef for WindowsPipeServerRef { impl Future for WindowsPipeServerRef {
fn is_finished(&self) -> bool { type Output = Result<(), JoinError>;
self.inner.is_finished()
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.inner.task).poll(cx)
} }
}
impl Deref for WindowsPipeServerRef {
type Target = ServerRef;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
fn shutdown(&self) { impl DerefMut for WindowsPipeServerRef {
self.inner.shutdown(); fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
} }
} }

@ -6,7 +6,7 @@ use distant_net::boxed_connect_handler;
use distant_net::client::Client; use distant_net::client::Client;
use distant_net::common::{Destination, InmemoryTransport, Map, OneshotListener}; use distant_net::common::{Destination, InmemoryTransport, Map, OneshotListener};
use distant_net::manager::{Config, ManagerClient, ManagerServer}; use distant_net::manager::{Config, ManagerClient, ManagerServer};
use distant_net::server::{Server, ServerCtx, ServerHandler}; use distant_net::server::{RequestCtx, Server, ServerHandler};
use log::*; use log::*;
use test_log::test; use test_log::test;
@ -14,11 +14,10 @@ struct TestServerHandler;
#[async_trait] #[async_trait]
impl ServerHandler for TestServerHandler { impl ServerHandler for TestServerHandler {
type LocalData = ();
type Request = String; type Request = String;
type Response = String; type Response = String;
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) { async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>) {
ctx.reply ctx.reply
.send(format!("echo {}", ctx.request.payload)) .send(format!("echo {}", ctx.request.payload))
.await .await
@ -37,7 +36,7 @@ async fn should_be_able_to_establish_a_single_connection_and_communicate_with_a_
let (t1, t2) = InmemoryTransport::pair(100); let (t1, t2) = InmemoryTransport::pair(100);
// Spawn a server on one end and connect to it on the other // Spawn a server on one end and connect to it on the other
let _ = Server::new() let _server = Server::new()
.handler(TestServerHandler) .handler(TestServerHandler)
.verifier(Verifier::none()) .verifier(Verifier::none())
.start(OneshotListener::from_value(t2))?; .start(OneshotListener::from_value(t2))?;

@ -2,7 +2,7 @@ use async_trait::async_trait;
use distant_auth::{DummyAuthHandler, Verifier}; use distant_auth::{DummyAuthHandler, Verifier};
use distant_net::client::Client; use distant_net::client::Client;
use distant_net::common::{InmemoryTransport, OneshotListener}; use distant_net::common::{InmemoryTransport, OneshotListener};
use distant_net::server::{Server, ServerCtx, ServerHandler}; use distant_net::server::{RequestCtx, Server, ServerHandler};
use log::*; use log::*;
use test_log::test; use test_log::test;
@ -10,11 +10,10 @@ struct TestServerHandler;
#[async_trait] #[async_trait]
impl ServerHandler for TestServerHandler { impl ServerHandler for TestServerHandler {
type LocalData = ();
type Request = (u8, String); type Request = (u8, String);
type Response = String; type Response = String;
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) { async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>) {
let (cnt, msg) = ctx.request.payload; let (cnt, msg) = ctx.request.payload;
for i in 0..cnt { for i in 0..cnt {
@ -30,7 +29,7 @@ impl ServerHandler for TestServerHandler {
async fn should_be_able_to_send_and_receive_typed_payloads_between_client_and_server() { async fn should_be_able_to_send_and_receive_typed_payloads_between_client_and_server() {
let (t1, t2) = InmemoryTransport::pair(100); let (t1, t2) = InmemoryTransport::pair(100);
let _ = Server::new() let _server = Server::new()
.handler(TestServerHandler) .handler(TestServerHandler)
.verifier(Verifier::none()) .verifier(Verifier::none())
.start(OneshotListener::from_value(t2)) .start(OneshotListener::from_value(t2))

@ -2,7 +2,7 @@ use async_trait::async_trait;
use distant_auth::{DummyAuthHandler, Verifier}; use distant_auth::{DummyAuthHandler, Verifier};
use distant_net::client::Client; use distant_net::client::Client;
use distant_net::common::{InmemoryTransport, OneshotListener, Request}; use distant_net::common::{InmemoryTransport, OneshotListener, Request};
use distant_net::server::{Server, ServerCtx, ServerHandler}; use distant_net::server::{RequestCtx, Server, ServerHandler};
use log::*; use log::*;
use test_log::test; use test_log::test;
@ -10,11 +10,10 @@ struct TestServerHandler;
#[async_trait] #[async_trait]
impl ServerHandler for TestServerHandler { impl ServerHandler for TestServerHandler {
type LocalData = ();
type Request = (u8, String); type Request = (u8, String);
type Response = String; type Response = String;
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) { async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>) {
let (cnt, msg) = ctx.request.payload; let (cnt, msg) = ctx.request.payload;
for i in 0..cnt { for i in 0..cnt {
@ -30,7 +29,7 @@ impl ServerHandler for TestServerHandler {
async fn should_be_able_to_send_and_receive_untyped_payloads_between_client_and_server() { async fn should_be_able_to_send_and_receive_untyped_payloads_between_client_and_server() {
let (t1, t2) = InmemoryTransport::pair(100); let (t1, t2) = InmemoryTransport::pair(100);
let _ = Server::new() let _server = Server::new()
.handler(TestServerHandler) .handler(TestServerHandler)
.verifier(Verifier::none()) .verifier(Verifier::none())
.start(OneshotListener::from_value(t2)) .start(OneshotListener::from_value(t2))

@ -1,4 +1,4 @@
use std::collections::{HashMap, HashSet, VecDeque}; use std::collections::{HashMap, VecDeque};
use std::io; use std::io;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::{Arc, Weak}; use std::sync::{Arc, Weak};
@ -7,7 +7,6 @@ use std::time::Duration;
use async_compat::CompatExt; use async_compat::CompatExt;
use async_once_cell::OnceCell; use async_once_cell::OnceCell;
use async_trait::async_trait; use async_trait::async_trait;
use distant_core::net::server::ConnectionCtx;
use distant_core::protocol::{ use distant_core::protocol::{
Capabilities, CapabilityKind, DirEntry, Environment, FileType, Metadata, Permissions, Capabilities, CapabilityKind, DirEntry, Environment, FileType, Metadata, Permissions,
ProcessId, PtySize, SetPermissionsOptions, SystemInfo, UnixMetadata, Version, PROTOCOL_VERSION, ProcessId, PtySize, SetPermissionsOptions, SystemInfo, UnixMetadata, Version, PROTOCOL_VERSION,
@ -25,16 +24,6 @@ use crate::utils::{self, to_other_error};
/// Time after copy completes to wait for stdout/stderr to close /// Time after copy completes to wait for stdout/stderr to close
const COPY_COMPLETE_TIMEOUT: Duration = Duration::from_secs(1); const COPY_COMPLETE_TIMEOUT: Duration = Duration::from_secs(1);
#[derive(Default)]
pub struct ConnectionState {
/// List of process ids that will be killed when the connection terminates
processes: Arc<RwLock<HashSet<ProcessId>>>,
/// Internal reference to global process list for removals
/// NOTE: Initialized during `on_accept` of [`DistantApi`]
global_processes: Weak<RwLock<HashMap<ProcessId, Process>>>,
}
struct Process { struct Process {
stdin_tx: mpsc::Sender<Vec<u8>>, stdin_tx: mpsc::Sender<Vec<u8>>,
kill_tx: mpsc::Sender<()>, kill_tx: mpsc::Sender<()>,
@ -72,18 +61,7 @@ impl SshDistantApi {
#[async_trait] #[async_trait]
impl DistantApi for SshDistantApi { impl DistantApi for SshDistantApi {
type LocalData = ConnectionState; async fn read_file(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<Vec<u8>> {
async fn on_accept(&self, ctx: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> {
ctx.local_data.global_processes = Arc::downgrade(&self.processes);
Ok(())
}
async fn read_file(
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
) -> io::Result<Vec<u8>> {
debug!( debug!(
"[Conn {}] Reading bytes from file {:?}", "[Conn {}] Reading bytes from file {:?}",
ctx.connection_id, path ctx.connection_id, path
@ -103,11 +81,7 @@ impl DistantApi for SshDistantApi {
Ok(contents.into_bytes()) Ok(contents.into_bytes())
} }
async fn read_file_text( async fn read_file_text(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<String> {
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
) -> io::Result<String> {
debug!( debug!(
"[Conn {}] Reading text from file {:?}", "[Conn {}] Reading text from file {:?}",
ctx.connection_id, path ctx.connection_id, path
@ -127,12 +101,7 @@ impl DistantApi for SshDistantApi {
Ok(contents) Ok(contents)
} }
async fn write_file( async fn write_file(&self, ctx: DistantCtx, path: PathBuf, data: Vec<u8>) -> io::Result<()> {
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
data: Vec<u8>,
) -> io::Result<()> {
debug!( debug!(
"[Conn {}] Writing bytes to file {:?}", "[Conn {}] Writing bytes to file {:?}",
ctx.connection_id, path ctx.connection_id, path
@ -154,7 +123,7 @@ impl DistantApi for SshDistantApi {
async fn write_file_text( async fn write_file_text(
&self, &self,
ctx: DistantCtx<Self::LocalData>, ctx: DistantCtx,
path: PathBuf, path: PathBuf,
data: String, data: String,
) -> io::Result<()> { ) -> io::Result<()> {
@ -177,12 +146,7 @@ impl DistantApi for SshDistantApi {
Ok(()) Ok(())
} }
async fn append_file( async fn append_file(&self, ctx: DistantCtx, path: PathBuf, data: Vec<u8>) -> io::Result<()> {
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
data: Vec<u8>,
) -> io::Result<()> {
debug!( debug!(
"[Conn {}] Appending bytes to file {:?}", "[Conn {}] Appending bytes to file {:?}",
ctx.connection_id, path ctx.connection_id, path
@ -213,7 +177,7 @@ impl DistantApi for SshDistantApi {
async fn append_file_text( async fn append_file_text(
&self, &self,
ctx: DistantCtx<Self::LocalData>, ctx: DistantCtx,
path: PathBuf, path: PathBuf,
data: String, data: String,
) -> io::Result<()> { ) -> io::Result<()> {
@ -247,7 +211,7 @@ impl DistantApi for SshDistantApi {
async fn read_dir( async fn read_dir(
&self, &self,
ctx: DistantCtx<Self::LocalData>, ctx: DistantCtx,
path: PathBuf, path: PathBuf,
depth: usize, depth: usize,
absolute: bool, absolute: bool,
@ -375,12 +339,7 @@ impl DistantApi for SshDistantApi {
Ok((entries, errors)) Ok((entries, errors))
} }
async fn create_dir( async fn create_dir(&self, ctx: DistantCtx, path: PathBuf, all: bool) -> io::Result<()> {
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
all: bool,
) -> io::Result<()> {
debug!( debug!(
"[Conn {}] Creating directory {:?} {{all: {}}}", "[Conn {}] Creating directory {:?} {{all: {}}}",
ctx.connection_id, path, all ctx.connection_id, path, all
@ -436,12 +395,7 @@ impl DistantApi for SshDistantApi {
Ok(()) Ok(())
} }
async fn remove( async fn remove(&self, ctx: DistantCtx, path: PathBuf, force: bool) -> io::Result<()> {
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
force: bool,
) -> io::Result<()> {
debug!( debug!(
"[Conn {}] Removing {:?} {{force: {}}}", "[Conn {}] Removing {:?} {{force: {}}}",
ctx.connection_id, path, force ctx.connection_id, path, force
@ -526,12 +480,7 @@ impl DistantApi for SshDistantApi {
Ok(()) Ok(())
} }
async fn copy( async fn copy(&self, ctx: DistantCtx, src: PathBuf, dst: PathBuf) -> io::Result<()> {
&self,
ctx: DistantCtx<Self::LocalData>,
src: PathBuf,
dst: PathBuf,
) -> io::Result<()> {
debug!( debug!(
"[Conn {}] Copying {:?} to {:?}", "[Conn {}] Copying {:?} to {:?}",
ctx.connection_id, src, dst ctx.connection_id, src, dst
@ -573,12 +522,7 @@ impl DistantApi for SshDistantApi {
} }
} }
async fn rename( async fn rename(&self, ctx: DistantCtx, src: PathBuf, dst: PathBuf) -> io::Result<()> {
&self,
ctx: DistantCtx<Self::LocalData>,
src: PathBuf,
dst: PathBuf,
) -> io::Result<()> {
debug!( debug!(
"[Conn {}] Renaming {:?} to {:?}", "[Conn {}] Renaming {:?} to {:?}",
ctx.connection_id, src, dst ctx.connection_id, src, dst
@ -594,7 +538,7 @@ impl DistantApi for SshDistantApi {
Ok(()) Ok(())
} }
async fn exists(&self, ctx: DistantCtx<Self::LocalData>, path: PathBuf) -> io::Result<bool> { async fn exists(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<bool> {
debug!("[Conn {}] Checking if {:?} exists", ctx.connection_id, path); debug!("[Conn {}] Checking if {:?} exists", ctx.connection_id, path);
// NOTE: SFTP does not provide a means to check if a path exists that can be performed // NOTE: SFTP does not provide a means to check if a path exists that can be performed
@ -612,7 +556,7 @@ impl DistantApi for SshDistantApi {
async fn metadata( async fn metadata(
&self, &self,
ctx: DistantCtx<Self::LocalData>, ctx: DistantCtx,
path: PathBuf, path: PathBuf,
canonicalize: bool, canonicalize: bool,
resolve_file_type: bool, resolve_file_type: bool,
@ -676,7 +620,7 @@ impl DistantApi for SshDistantApi {
#[allow(unreachable_code)] #[allow(unreachable_code)]
async fn set_permissions( async fn set_permissions(
&self, &self,
ctx: DistantCtx<Self::LocalData>, ctx: DistantCtx,
path: PathBuf, path: PathBuf,
permissions: Permissions, permissions: Permissions,
options: SetPermissionsOptions, options: SetPermissionsOptions,
@ -805,7 +749,7 @@ impl DistantApi for SshDistantApi {
async fn proc_spawn( async fn proc_spawn(
&self, &self,
ctx: DistantCtx<Self::LocalData>, ctx: DistantCtx,
cmd: String, cmd: String,
environment: Environment, environment: Environment,
current_dir: Option<PathBuf>, current_dir: Option<PathBuf>,
@ -817,14 +761,10 @@ impl DistantApi for SshDistantApi {
); );
let global_processes = Arc::downgrade(&self.processes); let global_processes = Arc::downgrade(&self.processes);
let local_processes = Arc::downgrade(&ctx.local_data.processes);
let cleanup = |id: ProcessId| async move { let cleanup = |id: ProcessId| async move {
if let Some(processes) = Weak::upgrade(&global_processes) { if let Some(processes) = Weak::upgrade(&global_processes) {
processes.write().await.remove(&id); processes.write().await.remove(&id);
} }
if let Some(processes) = Weak::upgrade(&local_processes) {
processes.write().await.remove(&id);
}
}; };
let SpawnResult { let SpawnResult {
@ -874,7 +814,7 @@ impl DistantApi for SshDistantApi {
Ok(id) Ok(id)
} }
async fn proc_kill(&self, ctx: DistantCtx<Self::LocalData>, id: ProcessId) -> io::Result<()> { async fn proc_kill(&self, ctx: DistantCtx, id: ProcessId) -> io::Result<()> {
debug!("[Conn {}] Killing process {}", ctx.connection_id, id); debug!("[Conn {}] Killing process {}", ctx.connection_id, id);
if let Some(process) = self.processes.read().await.get(&id) { if let Some(process) = self.processes.read().await.get(&id) {
@ -892,12 +832,7 @@ impl DistantApi for SshDistantApi {
)) ))
} }
async fn proc_stdin( async fn proc_stdin(&self, ctx: DistantCtx, id: ProcessId, data: Vec<u8>) -> io::Result<()> {
&self,
ctx: DistantCtx<Self::LocalData>,
id: ProcessId,
data: Vec<u8>,
) -> io::Result<()> {
debug!( debug!(
"[Conn {}] Sending stdin to process {}", "[Conn {}] Sending stdin to process {}",
ctx.connection_id, id ctx.connection_id, id
@ -920,7 +855,7 @@ impl DistantApi for SshDistantApi {
async fn proc_resize_pty( async fn proc_resize_pty(
&self, &self,
ctx: DistantCtx<Self::LocalData>, ctx: DistantCtx,
id: ProcessId, id: ProcessId,
size: PtySize, size: PtySize,
) -> io::Result<()> { ) -> io::Result<()> {
@ -944,7 +879,7 @@ impl DistantApi for SshDistantApi {
)) ))
} }
async fn system_info(&self, ctx: DistantCtx<Self::LocalData>) -> io::Result<SystemInfo> { async fn system_info(&self, ctx: DistantCtx) -> io::Result<SystemInfo> {
// We cache each of these requested values since they should not change for the // We cache each of these requested values since they should not change for the
// lifetime of the ssh connection // lifetime of the ssh connection
static CURRENT_DIR: OnceCell<PathBuf> = OnceCell::new(); static CURRENT_DIR: OnceCell<PathBuf> = OnceCell::new();
@ -998,7 +933,7 @@ impl DistantApi for SshDistantApi {
}) })
} }
async fn version(&self, ctx: DistantCtx<Self::LocalData>) -> io::Result<Version> { async fn version(&self, ctx: DistantCtx) -> io::Result<Version> {
debug!("[Conn {}] Querying capabilities", ctx.connection_id); debug!("[Conn {}] Querying capabilities", ctx.connection_id);
let mut capabilities = Capabilities::all(); let mut capabilities = Capabilities::all();

@ -722,7 +722,7 @@ impl Ssh {
} }
/// Consumes [`Ssh`] and produces a [`DistantClient`] and [`ServerRef`] pair. /// Consumes [`Ssh`] and produces a [`DistantClient`] and [`ServerRef`] pair.
pub async fn into_distant_pair(self) -> io::Result<(DistantClient, Box<dyn ServerRef>)> { pub async fn into_distant_pair(self) -> io::Result<(DistantClient, ServerRef)> {
// Exit early if not authenticated as this is a requirement // Exit early if not authenticated as this is a requirement
if !self.authenticated { if !self.authenticated {
return Err(io::Error::new( return Err(io::Error::new(

@ -13,7 +13,7 @@ pub(crate) use common::{Cache, Client, Manager};
/// Represents the primary CLI entrypoint /// Represents the primary CLI entrypoint
#[derive(Debug)] #[derive(Debug)]
pub struct Cli { pub struct Cli {
options: Options, pub options: Options,
} }
impl Cli { impl Cli {

@ -1,19 +1,16 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::io; use std::io;
use std::io::Write; use std::io::Write;
use std::path::Path; use std::path::{Path, PathBuf};
use std::path::PathBuf;
use std::time::Duration; use std::time::Duration;
use anyhow::Context; use anyhow::Context;
use distant_core::net::common::{ConnectionId, Host, Map, Request, Response}; use distant_core::net::common::{ConnectionId, Host, Map, Request, Response};
use distant_core::net::manager::ManagerClient; use distant_core::net::manager::ManagerClient;
use distant_core::protocol::SearchQueryContentsMatch;
use distant_core::protocol::SearchQueryMatch;
use distant_core::protocol::SearchQueryPathMatch;
use distant_core::protocol::{ use distant_core::protocol::{
self, Capabilities, ChangeKind, ChangeKindSet, FileType, Permissions, SearchQuery, self, Capabilities, ChangeKind, ChangeKindSet, FileType, Permissions, SearchQuery,
SetPermissionsOptions, SystemInfo, SearchQueryContentsMatch, SearchQueryMatch, SearchQueryPathMatch, SetPermissionsOptions,
SystemInfo,
}; };
use distant_core::{DistantChannel, DistantChannelExt, RemoteCommand, Searcher, Watcher}; use distant_core::{DistantChannel, DistantChannelExt, RemoteCommand, Searcher, Watcher};
use log::*; use log::*;

@ -185,7 +185,7 @@ async fn async_run(cmd: ManagerSubcommand) -> CliResult {
"global".to_string() "global".to_string()
} }
); );
let manager_ref = Manager { let manager = Manager {
access, access,
config: NetManagerConfig { config: NetManagerConfig {
user, user,
@ -223,11 +223,7 @@ async fn async_run(cmd: ManagerSubcommand) -> CliResult {
.context("Failed to start manager")?; .context("Failed to start manager")?;
// Let our server run to completion // Let our server run to completion
manager_ref manager.await.context("Failed to wait on manager")?;
.as_ref()
.polling_wait()
.await
.context("Failed to wait on manager")?;
info!("Manager is shutting down"); info!("Manager is shutting down");
Ok(()) Ok(())

@ -3,7 +3,7 @@ use std::io::{self, Read, Write};
use anyhow::Context; use anyhow::Context;
use distant_core::net::auth::Verifier; use distant_core::net::auth::Verifier;
use distant_core::net::common::{Host, SecretKey32}; use distant_core::net::common::{Host, SecretKey32};
use distant_core::net::server::{Server, ServerConfig as NetServerConfig, ServerRef}; use distant_core::net::server::{Server, ServerConfig as NetServerConfig};
use distant_core::DistantSingleKeyCredentials; use distant_core::DistantSingleKeyCredentials;
use distant_local::{Config as LocalConfig, WatchConfig as LocalWatchConfig}; use distant_local::{Config as LocalConfig, WatchConfig as LocalWatchConfig};
use log::*; use log::*;
@ -212,7 +212,7 @@ async fn async_run(cmd: ServerSubcommand, _is_forked: bool) -> CliResult {
} }
// Let our server run to completion // Let our server run to completion
server.wait().await.context("Failed to wait on server")?; server.await.context("Failed to wait on server")?;
info!("Server is shutting down"); info!("Server is shutting down");
} }
} }

@ -15,7 +15,7 @@ pub struct Manager {
impl Manager { impl Manager {
/// Begin listening on the network interface specified within [`NetworkConfig`] /// Begin listening on the network interface specified within [`NetworkConfig`]
pub async fn listen(self) -> anyhow::Result<Box<dyn ServerRef>> { pub async fn listen(self) -> anyhow::Result<ServerRef> {
let user = self.config.user; let user = self.config.user;
#[cfg(unix)] #[cfg(unix)]
@ -36,7 +36,7 @@ impl Manager {
.with_context(|| format!("Failed to create socket directory {parent:?}"))?; .with_context(|| format!("Failed to create socket directory {parent:?}"))?;
} }
let boxed_ref = ManagerServer::new(self.config) let server = ManagerServer::new(self.config)
.verifier(Verifier::none()) .verifier(Verifier::none())
.start( .start(
UnixSocketListener::bind_with_permissions(socket_path, self.access.into_mode()) UnixSocketListener::bind_with_permissions(socket_path, self.access.into_mode())
@ -45,7 +45,7 @@ impl Manager {
.with_context(|| format!("Failed to start manager at socket {socket_path:?}"))?; .with_context(|| format!("Failed to start manager at socket {socket_path:?}"))?;
info!("Manager listening using unix socket @ {:?}", socket_path); info!("Manager listening using unix socket @ {:?}", socket_path);
Ok(boxed_ref) Ok(server)
} }
#[cfg(windows)] #[cfg(windows)]
@ -57,13 +57,13 @@ impl Manager {
global_paths::WINDOWS_PIPE_NAME.as_str() global_paths::WINDOWS_PIPE_NAME.as_str()
}); });
let boxed_ref = ManagerServer::new(self.config) let server = ManagerServer::new(self.config)
.verifier(Verifier::none()) .verifier(Verifier::none())
.start(WindowsPipeListener::bind_local(pipe_name)?) .start(WindowsPipeListener::bind_local(pipe_name)?)
.with_context(|| format!("Failed to start manager at pipe {pipe_name:?}"))?; .with_context(|| format!("Failed to start manager at pipe {pipe_name:?}"))?;
info!("Manager listening using windows pipe @ {:?}", pipe_name); info!("Manager listening using windows pipe @ {:?}", pipe_name);
Ok(boxed_ref) Ok(server)
} }
} }
} }

@ -7,6 +7,7 @@ fn main() -> MainResult {
Err(x) => return MainResult::from(x), Err(x) => return MainResult::from(x),
}; };
let _logger = cli.init_logger(); let _logger = cli.init_logger();
MainResult::from(cli.run()) MainResult::from(cli.run())
} }

@ -28,6 +28,10 @@ pub struct Options {
#[clap(flatten)] #[clap(flatten)]
pub logging: LoggingSettings, pub logging: LoggingSettings,
#[cfg(feature = "tracing")]
#[clap(long, global = true)]
pub tracing: bool,
/// Configuration file to load instead of the default paths /// Configuration file to load instead of the default paths
#[clap(short = 'c', long = "config", global = true, value_parser)] #[clap(short = 'c', long = "config", global = true, value_parser)]
config_path: Option<PathBuf>, config_path: Option<PathBuf>,

@ -4,7 +4,7 @@ use rstest::*;
use serde_json::json; use serde_json::json;
use test_log::test; use test_log::test;
use crate::cli::fixtures::*; use crate::common::fixtures::*;
const FILE_CONTENTS: &str = r#" const FILE_CONTENTS: &str = r#"
some text some text

@ -4,7 +4,7 @@ use rstest::*;
use serde_json::json; use serde_json::json;
use test_log::test; use test_log::test;
use crate::cli::fixtures::*; use crate::common::fixtures::*;
#[rstest] #[rstest]
#[test(tokio::test)] #[test(tokio::test)]

@ -5,7 +5,7 @@ use rstest::*;
use serde_json::json; use serde_json::json;
use test_log::test; use test_log::test;
use crate::cli::fixtures::*; use crate::common::fixtures::*;
/// Creates a directory in the form /// Creates a directory in the form
/// ///

@ -3,7 +3,7 @@ use rstest::*;
use serde_json::json; use serde_json::json;
use test_log::test; use test_log::test;
use crate::cli::fixtures::*; use crate::common::fixtures::*;
#[rstest] #[rstest]
#[test(tokio::test)] #[test(tokio::test)]

@ -3,7 +3,7 @@ use rstest::*;
use serde_json::json; use serde_json::json;
use test_log::test; use test_log::test;
use crate::cli::fixtures::*; use crate::common::fixtures::*;
const FILE_CONTENTS: &str = r#" const FILE_CONTENTS: &str = r#"
some text some text

@ -3,7 +3,7 @@ use rstest::*;
use serde_json::json; use serde_json::json;
use test_log::test; use test_log::test;
use crate::cli::fixtures::*; use crate::common::fixtures::*;
const FILE_CONTENTS: &str = r#" const FILE_CONTENTS: &str = r#"
some text some text

@ -3,7 +3,7 @@ use rstest::*;
use serde_json::json; use serde_json::json;
use test_log::test; use test_log::test;
use crate::cli::fixtures::*; use crate::common::fixtures::*;
const FILE_CONTENTS: &str = r#" const FILE_CONTENTS: &str = r#"
some text some text

@ -3,7 +3,7 @@ use rstest::*;
use serde_json::json; use serde_json::json;
use test_log::test; use test_log::test;
use crate::cli::fixtures::*; use crate::common::fixtures::*;
const FILE_CONTENTS: &str = r#" const FILE_CONTENTS: &str = r#"
some text some text

@ -3,7 +3,7 @@ use rstest::*;
use serde_json::json; use serde_json::json;
use test_log::test; use test_log::test;
use crate::cli::fixtures::*; use crate::common::fixtures::*;
const FILE_CONTENTS: &str = r#" const FILE_CONTENTS: &str = r#"
some text some text

@ -3,7 +3,7 @@ use rstest::*;
use serde_json::json; use serde_json::json;
use test_log::test; use test_log::test;
use crate::cli::fixtures::*; use crate::common::fixtures::*;
const FILE_CONTENTS: &str = r#" const FILE_CONTENTS: &str = r#"
some text some text

@ -3,7 +3,7 @@ use rstest::*;
use serde_json::{json, Value}; use serde_json::{json, Value};
use test_log::test; use test_log::test;
use crate::cli::fixtures::*; use crate::common::fixtures::*;
const FILE_CONTENTS: &str = r#" const FILE_CONTENTS: &str = r#"
some text some text

@ -2,8 +2,8 @@ use rstest::*;
use serde_json::json; use serde_json::json;
use test_log::test; use test_log::test;
use crate::cli::fixtures::*;
use crate::cli::scripts::*; use crate::cli::scripts::*;
use crate::common::fixtures::*;
fn make_cmd(args: Vec<&str>) -> String { fn make_cmd(args: Vec<&str>) -> String {
format!( format!(

@ -4,7 +4,7 @@ use rstest::*;
use serde_json::json; use serde_json::json;
use test_log::test; use test_log::test;
use crate::cli::fixtures::*; use crate::common::fixtures::*;
#[rstest] #[rstest]
#[test(tokio::test)] #[test(tokio::test)]

@ -4,7 +4,7 @@ use rstest::*;
use serde_json::json; use serde_json::json;
use test_log::test; use test_log::test;
use crate::cli::fixtures::*; use crate::common::fixtures::*;
const FILE_CONTENTS: &str = r#" const FILE_CONTENTS: &str = r#"
some text some text

@ -3,7 +3,7 @@ use rstest::*;
use serde_json::json; use serde_json::json;
use test_log::test; use test_log::test;
use crate::cli::fixtures::*; use crate::common::fixtures::*;
#[rstest] #[rstest]
#[test(tokio::test)] #[test(tokio::test)]

@ -4,7 +4,7 @@ use rstest::*;
use serde_json::json; use serde_json::json;
use test_log::test; use test_log::test;
use crate::cli::fixtures::*; use crate::common::fixtures::*;
#[rstest] #[rstest]
#[test(tokio::test)] #[test(tokio::test)]

@ -3,7 +3,7 @@ use rstest::*;
use serde_json::json; use serde_json::json;
use test_log::test; use test_log::test;
use crate::cli::fixtures::*; use crate::common::fixtures::*;
#[rstest] #[rstest]
#[test(tokio::test)] #[test(tokio::test)]

@ -5,7 +5,7 @@ use rstest::*;
use serde_json::json; use serde_json::json;
use test_log::test; use test_log::test;
use crate::cli::fixtures::*; use crate::common::fixtures::*;
async fn wait_a_bit() { async fn wait_a_bit() {
wait_millis(250).await; wait_millis(250).await;

@ -2,7 +2,7 @@ use assert_fs::prelude::*;
use predicates::prelude::*; use predicates::prelude::*;
use rstest::*; use rstest::*;
use crate::cli::fixtures::*; use crate::common::fixtures::*;
const FILE_CONTENTS: &str = r#" const FILE_CONTENTS: &str = r#"
some text some text

@ -1,7 +1,7 @@
use assert_fs::prelude::*; use assert_fs::prelude::*;
use rstest::*; use rstest::*;
use crate::cli::fixtures::*; use crate::common::fixtures::*;
#[rstest] #[rstest]
#[test_log::test] #[test_log::test]

@ -2,7 +2,7 @@ use assert_fs::prelude::*;
use predicates::prelude::*; use predicates::prelude::*;
use rstest::*; use rstest::*;
use crate::cli::fixtures::*; use crate::common::fixtures::*;
#[rstest] #[rstest]
#[test_log::test] #[test_log::test]

@ -2,8 +2,8 @@ use assert_fs::prelude::*;
use predicates::prelude::*; use predicates::prelude::*;
use rstest::*; use rstest::*;
use crate::cli::fixtures::*; use crate::common::fixtures::*;
use crate::cli::utils::regex_pred; use crate::common::utils::regex_pred;
const FILE_CONTENTS: &str = r#" const FILE_CONTENTS: &str = r#"
some text some text

@ -4,8 +4,8 @@ use assert_fs::prelude::*;
use predicates::prelude::*; use predicates::prelude::*;
use rstest::*; use rstest::*;
use crate::cli::fixtures::*; use crate::common::fixtures::*;
use crate::cli::utils::regex_pred; use crate::common::utils::regex_pred;
/// Creates a directory in the form /// Creates a directory in the form
/// ///

@ -3,7 +3,7 @@ use indoc::indoc;
use predicates::prelude::*; use predicates::prelude::*;
use rstest::*; use rstest::*;
use crate::cli::fixtures::*; use crate::common::fixtures::*;
const FILE_CONTENTS: &str = indoc! {r#" const FILE_CONTENTS: &str = indoc! {r#"
some text some text

@ -2,7 +2,7 @@ use assert_fs::prelude::*;
use predicates::prelude::*; use predicates::prelude::*;
use rstest::*; use rstest::*;
use crate::cli::fixtures::*; use crate::common::fixtures::*;
#[rstest] #[rstest]
#[test_log::test] #[test_log::test]

@ -2,7 +2,7 @@ use assert_fs::prelude::*;
use predicates::prelude::*; use predicates::prelude::*;
use rstest::*; use rstest::*;
use crate::cli::fixtures::*; use crate::common::fixtures::*;
const FILE_CONTENTS: &str = r#" const FILE_CONTENTS: &str = r#"
some text some text

@ -3,7 +3,7 @@ use indoc::indoc;
use predicates::Predicate; use predicates::Predicate;
use rstest::*; use rstest::*;
use crate::cli::fixtures::*; use crate::common::fixtures::*;
const SEARCH_RESULTS_REGEX: &str = indoc! {r" const SEARCH_RESULTS_REGEX: &str = indoc! {r"
.*?[\\/]file1.txt .*?[\\/]file1.txt

@ -4,8 +4,8 @@ use std::time::Duration;
use assert_fs::prelude::*; use assert_fs::prelude::*;
use rstest::*; use rstest::*;
use crate::cli::fixtures::*; use crate::common::fixtures::*;
use crate::cli::utils::ThreadedReader; use crate::common::utils::ThreadedReader;
fn wait_a_bit() { fn wait_a_bit() {
wait_millis(250); wait_millis(250);

@ -3,7 +3,7 @@ use indoc::indoc;
use predicates::prelude::*; use predicates::prelude::*;
use rstest::*; use rstest::*;
use crate::cli::fixtures::*; use crate::common::fixtures::*;
const FILE_CONTENTS: &str = indoc! {r#" const FILE_CONTENTS: &str = indoc! {r#"
some text some text

@ -1,8 +1,8 @@
use rstest::*; use rstest::*;
use crate::cli::fixtures::*;
use crate::cli::scripts::*; use crate::cli::scripts::*;
use crate::cli::utils::regex_pred; use crate::common::fixtures::*;
use crate::common::utils::regex_pred;
#[rstest] #[rstest]
#[test_log::test] #[test_log::test]

@ -2,7 +2,7 @@ use std::env;
use rstest::*; use rstest::*;
use crate::cli::fixtures::*; use crate::common::fixtures::*;
#[rstest] #[rstest]
#[test_log::test] #[test_log::test]

@ -1,8 +1,8 @@
use distant_core::protocol::PROTOCOL_VERSION; use distant_core::protocol::PROTOCOL_VERSION;
use rstest::*; use rstest::*;
use crate::cli::fixtures::*; use crate::common::fixtures::*;
use crate::cli::utils::TrimmedLinesMatchPredicate; use crate::common::utils::TrimmedLinesMatchPredicate;
#[rstest] #[rstest]
#[test_log::test] #[test_log::test]

@ -1,7 +1,7 @@
use indoc::indoc; use indoc::indoc;
use rstest::*; use rstest::*;
use crate::cli::fixtures::*; use crate::common::fixtures::*;
const EXPECTED_TABLE: &str = indoc! {" const EXPECTED_TABLE: &str = indoc! {"
+---------------+--------------------------------------------------------------+ +---------------+--------------------------------------------------------------+

@ -1,6 +1,4 @@
mod api; mod api;
mod client; mod client;
mod fixtures;
mod manager; mod manager;
mod scripts; mod scripts;
mod utils;

@ -1 +1,2 @@
mod cli; mod cli;
mod common;

@ -0,0 +1,4 @@
#![allow(dead_code)]
pub mod fixtures;
pub mod utils;

@ -0,0 +1,51 @@
use assert_fs::prelude::*;
use rstest::*;
mod common;
use common::fixtures::*;
#[rstest]
#[test_log::test]
#[ignore]
fn should_handle_large_volume_of_requests(ctx: DistantManagerCtx) {
// Create a temporary directory to house a file we create and edit
// with a large volume of requests
let root = assert_fs::TempDir::new().unwrap();
// Establish a path to a file we will edit repeatedly
let path = root.child("file").to_path_buf();
// Perform many requests of writing a file and reading a file
for i in 1..100 {
ctx.new_assert_cmd(["fs", "write"])
.arg(path.to_str().unwrap())
.write_stdin(format!("idx: {i}"))
.assert();
ctx.new_assert_cmd(["fs", "read"])
.arg(path.to_str().unwrap())
.assert()
.stdout(format!("idx: {i}"));
}
}
#[rstest]
#[test_log::test]
#[ignore]
fn should_handle_wide_spread_of_clients(_ctx: DistantManagerCtx) {
todo!();
}
#[rstest]
#[test_log::test]
#[ignore]
fn should_handle_abrupt_client_disconnects(_ctx: DistantManagerCtx) {
todo!();
}
#[rstest]
#[test_log::test]
#[ignore]
fn should_handle_badly_killing_client_shell_with_interactive_process(_ctx: DistantManagerCtx) {
todo!();
}
Loading…
Cancel
Save