Fix server hangup (#206)

pull/207/head
Chip Senkbeil 11 months ago committed by GitHub
parent 8009cc9361
commit da75801639
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `Request` and `Response` types from `distant-net` now support an optional
`Header` to send miscellaneous information
- New feature `tracing` provides https://github.com/tokio-rs/tracing support
as a new `--tracing` flag. Must be compiled with
`RUSTFLAGS="--cfg tokio_unstable"` to properly operate.
### Changed
@ -21,6 +24,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `DistantApi` now handles batch requests in parallel, returning the results in
order. To achieve the previous sequential processing of batch requests, the
header value `sequence` needs to be set to true
- Rename `GenericServerRef` to `ServerRef` and remove `ServerRef` trait,
refactoring `TcpServerRef`, `UnixSocketServerRef`, and `WindowsPipeServerRef`
to use the struct instead of `Box<dyn ServerRef>`
## [0.20.0-alpha.8]

@ -5,11 +5,13 @@ use async_trait::async_trait;
use crate::authenticator::Authenticator;
use crate::methods::AuthenticationMethod;
/// Authenticaton method for a static secret key
/// Authenticaton method that skips authentication and approves anything.
#[derive(Clone, Debug)]
pub struct NoneAuthenticationMethod;
impl NoneAuthenticationMethod {
pub const ID: &str = "none";
#[inline]
pub fn new() -> Self {
Self
@ -26,7 +28,7 @@ impl Default for NoneAuthenticationMethod {
#[async_trait]
impl AuthenticationMethod for NoneAuthenticationMethod {
fn id(&self) -> &'static str {
"none"
Self::ID
}
async fn authenticate(&self, _: &mut dyn Authenticator) -> io::Result<()> {

@ -14,6 +14,8 @@ pub struct StaticKeyAuthenticationMethod<T> {
}
impl<T> StaticKeyAuthenticationMethod<T> {
pub const ID: &str = "static_key";
#[inline]
pub fn new(key: T) -> Self {
Self { key }
@ -26,7 +28,7 @@ where
T: FromStr + PartialEq + Send + Sync,
{
fn id(&self) -> &'static str {
"static_key"
Self::ID
}
async fn authenticate(&self, authenticator: &mut dyn Authenticator) -> io::Result<()> {

@ -4,7 +4,7 @@ use std::sync::Arc;
use async_trait::async_trait;
use distant_net::common::ConnectionId;
use distant_net::server::{ConnectionCtx, Reply, ServerCtx, ServerHandler};
use distant_net::server::{Reply, RequestCtx, ServerHandler};
use log::*;
use crate::protocol::{
@ -16,23 +16,22 @@ mod reply;
use reply::DistantSingleReply;
/// Represents the context provided to the [`DistantApi`] for incoming requests
pub struct DistantCtx<T> {
pub struct DistantCtx {
pub connection_id: ConnectionId,
pub reply: Box<dyn Reply<Data = protocol::Response>>,
pub local_data: Arc<T>,
}
/// Represents a [`ServerHandler`] that leverages an API compliant with `distant`
pub struct DistantApiServerHandler<T, D>
pub struct DistantApiServerHandler<T>
where
T: DistantApi<LocalData = D>,
T: DistantApi,
{
api: Arc<T>,
}
impl<T, D> DistantApiServerHandler<T, D>
impl<T> DistantApiServerHandler<T>
where
T: DistantApi<LocalData = D>,
T: DistantApi,
{
pub fn new(api: T) -> Self {
Self { api: Arc::new(api) }
@ -51,12 +50,15 @@ fn unsupported<T>(label: &str) -> io::Result<T> {
/// which can be used to build other servers that are compatible with distant
#[async_trait]
pub trait DistantApi {
type LocalData: Send + Sync;
/// Invoked whenever a new connection is established.
#[allow(unused_variables)]
async fn on_connect(&self, id: ConnectionId) -> io::Result<()> {
Ok(())
}
/// Invoked whenever a new connection is established, providing a mutable reference to the
/// newly-created local data. This is a way to support modifying local data before it is used.
/// Invoked whenever an existing connection is dropped.
#[allow(unused_variables)]
async fn on_accept(&self, ctx: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> {
async fn on_disconnect(&self, id: ConnectionId) -> io::Result<()> {
Ok(())
}
@ -64,7 +66,7 @@ pub trait DistantApi {
///
/// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)]
async fn version(&self, ctx: DistantCtx<Self::LocalData>) -> io::Result<Version> {
async fn version(&self, ctx: DistantCtx) -> io::Result<Version> {
unsupported("version")
}
@ -74,11 +76,7 @@ pub trait DistantApi {
///
/// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)]
async fn read_file(
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
) -> io::Result<Vec<u8>> {
async fn read_file(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<Vec<u8>> {
unsupported("read_file")
}
@ -88,11 +86,7 @@ pub trait DistantApi {
///
/// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)]
async fn read_file_text(
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
) -> io::Result<String> {
async fn read_file_text(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<String> {
unsupported("read_file_text")
}
@ -103,12 +97,7 @@ pub trait DistantApi {
///
/// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)]
async fn write_file(
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
data: Vec<u8>,
) -> io::Result<()> {
async fn write_file(&self, ctx: DistantCtx, path: PathBuf, data: Vec<u8>) -> io::Result<()> {
unsupported("write_file")
}
@ -121,7 +110,7 @@ pub trait DistantApi {
#[allow(unused_variables)]
async fn write_file_text(
&self,
ctx: DistantCtx<Self::LocalData>,
ctx: DistantCtx,
path: PathBuf,
data: String,
) -> io::Result<()> {
@ -135,12 +124,7 @@ pub trait DistantApi {
///
/// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)]
async fn append_file(
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
data: Vec<u8>,
) -> io::Result<()> {
async fn append_file(&self, ctx: DistantCtx, path: PathBuf, data: Vec<u8>) -> io::Result<()> {
unsupported("append_file")
}
@ -153,7 +137,7 @@ pub trait DistantApi {
#[allow(unused_variables)]
async fn append_file_text(
&self,
ctx: DistantCtx<Self::LocalData>,
ctx: DistantCtx,
path: PathBuf,
data: String,
) -> io::Result<()> {
@ -172,7 +156,7 @@ pub trait DistantApi {
#[allow(unused_variables)]
async fn read_dir(
&self,
ctx: DistantCtx<Self::LocalData>,
ctx: DistantCtx,
path: PathBuf,
depth: usize,
absolute: bool,
@ -189,12 +173,7 @@ pub trait DistantApi {
///
/// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)]
async fn create_dir(
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
all: bool,
) -> io::Result<()> {
async fn create_dir(&self, ctx: DistantCtx, path: PathBuf, all: bool) -> io::Result<()> {
unsupported("create_dir")
}
@ -205,12 +184,7 @@ pub trait DistantApi {
///
/// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)]
async fn copy(
&self,
ctx: DistantCtx<Self::LocalData>,
src: PathBuf,
dst: PathBuf,
) -> io::Result<()> {
async fn copy(&self, ctx: DistantCtx, src: PathBuf, dst: PathBuf) -> io::Result<()> {
unsupported("copy")
}
@ -221,12 +195,7 @@ pub trait DistantApi {
///
/// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)]
async fn remove(
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
force: bool,
) -> io::Result<()> {
async fn remove(&self, ctx: DistantCtx, path: PathBuf, force: bool) -> io::Result<()> {
unsupported("remove")
}
@ -237,12 +206,7 @@ pub trait DistantApi {
///
/// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)]
async fn rename(
&self,
ctx: DistantCtx<Self::LocalData>,
src: PathBuf,
dst: PathBuf,
) -> io::Result<()> {
async fn rename(&self, ctx: DistantCtx, src: PathBuf, dst: PathBuf) -> io::Result<()> {
unsupported("rename")
}
@ -257,7 +221,7 @@ pub trait DistantApi {
#[allow(unused_variables)]
async fn watch(
&self,
ctx: DistantCtx<Self::LocalData>,
ctx: DistantCtx,
path: PathBuf,
recursive: bool,
only: Vec<ChangeKind>,
@ -272,7 +236,7 @@ pub trait DistantApi {
///
/// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)]
async fn unwatch(&self, ctx: DistantCtx<Self::LocalData>, path: PathBuf) -> io::Result<()> {
async fn unwatch(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<()> {
unsupported("unwatch")
}
@ -282,7 +246,7 @@ pub trait DistantApi {
///
/// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)]
async fn exists(&self, ctx: DistantCtx<Self::LocalData>, path: PathBuf) -> io::Result<bool> {
async fn exists(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<bool> {
unsupported("exists")
}
@ -296,7 +260,7 @@ pub trait DistantApi {
#[allow(unused_variables)]
async fn metadata(
&self,
ctx: DistantCtx<Self::LocalData>,
ctx: DistantCtx,
path: PathBuf,
canonicalize: bool,
resolve_file_type: bool,
@ -314,7 +278,7 @@ pub trait DistantApi {
#[allow(unused_variables)]
async fn set_permissions(
&self,
ctx: DistantCtx<Self::LocalData>,
ctx: DistantCtx,
path: PathBuf,
permissions: Permissions,
options: SetPermissionsOptions,
@ -328,11 +292,7 @@ pub trait DistantApi {
///
/// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)]
async fn search(
&self,
ctx: DistantCtx<Self::LocalData>,
query: SearchQuery,
) -> io::Result<SearchId> {
async fn search(&self, ctx: DistantCtx, query: SearchQuery) -> io::Result<SearchId> {
unsupported("search")
}
@ -342,11 +302,7 @@ pub trait DistantApi {
///
/// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)]
async fn cancel_search(
&self,
ctx: DistantCtx<Self::LocalData>,
id: SearchId,
) -> io::Result<()> {
async fn cancel_search(&self, ctx: DistantCtx, id: SearchId) -> io::Result<()> {
unsupported("cancel_search")
}
@ -361,7 +317,7 @@ pub trait DistantApi {
#[allow(unused_variables)]
async fn proc_spawn(
&self,
ctx: DistantCtx<Self::LocalData>,
ctx: DistantCtx,
cmd: String,
environment: Environment,
current_dir: Option<PathBuf>,
@ -376,7 +332,7 @@ pub trait DistantApi {
///
/// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)]
async fn proc_kill(&self, ctx: DistantCtx<Self::LocalData>, id: ProcessId) -> io::Result<()> {
async fn proc_kill(&self, ctx: DistantCtx, id: ProcessId) -> io::Result<()> {
unsupported("proc_kill")
}
@ -387,12 +343,7 @@ pub trait DistantApi {
///
/// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)]
async fn proc_stdin(
&self,
ctx: DistantCtx<Self::LocalData>,
id: ProcessId,
data: Vec<u8>,
) -> io::Result<()> {
async fn proc_stdin(&self, ctx: DistantCtx, id: ProcessId, data: Vec<u8>) -> io::Result<()> {
unsupported("proc_stdin")
}
@ -405,7 +356,7 @@ pub trait DistantApi {
#[allow(unused_variables)]
async fn proc_resize_pty(
&self,
ctx: DistantCtx<Self::LocalData>,
ctx: DistantCtx,
id: ProcessId,
size: PtySize,
) -> io::Result<()> {
@ -416,32 +367,34 @@ pub trait DistantApi {
///
/// *Override this, otherwise it will return "unsupported" as an error.*
#[allow(unused_variables)]
async fn system_info(&self, ctx: DistantCtx<Self::LocalData>) -> io::Result<SystemInfo> {
async fn system_info(&self, ctx: DistantCtx) -> io::Result<SystemInfo> {
unsupported("system_info")
}
}
#[async_trait]
impl<T, D> ServerHandler for DistantApiServerHandler<T, D>
impl<T> ServerHandler for DistantApiServerHandler<T>
where
T: DistantApi<LocalData = D> + Send + Sync + 'static,
D: Send + Sync + 'static,
T: DistantApi + Send + Sync + 'static,
{
type LocalData = D;
type Request = protocol::Msg<protocol::Request>;
type Response = protocol::Msg<protocol::Response>;
/// Overridden to leverage [`DistantApi`] implementation of `on_accept`
async fn on_accept(&self, ctx: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> {
T::on_accept(&self.api, ctx).await
/// Overridden to leverage [`DistantApi`] implementation of `on_connect`.
async fn on_connect(&self, id: ConnectionId) -> io::Result<()> {
T::on_connect(&self.api, id).await
}
/// Overridden to leverage [`DistantApi`] implementation of `on_disconnect`.
async fn on_disconnect(&self, id: ConnectionId) -> io::Result<()> {
T::on_disconnect(&self.api, id).await
}
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) {
let ServerCtx {
async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>) {
let RequestCtx {
connection_id,
request,
reply,
local_data,
} = ctx;
// Convert our reply to a queued reply so we can ensure that the result
@ -454,7 +407,6 @@ where
let ctx = DistantCtx {
connection_id,
reply: Box::new(DistantSingleReply::from(reply.clone_reply())),
local_data,
};
let data = handle_request(Arc::clone(&self.api), ctx, data).await;
@ -485,7 +437,6 @@ where
let ctx = DistantCtx {
connection_id,
reply: Box::new(DistantSingleReply::from(reply.clone_reply())),
local_data: Arc::clone(&local_data),
};
let data = handle_request(Arc::clone(&self.api), ctx, data).await;
@ -513,7 +464,6 @@ where
let ctx = DistantCtx {
connection_id,
reply: Box::new(DistantSingleReply::from(reply.clone_reply())),
local_data: Arc::clone(&local_data),
};
let task = tokio::spawn(async move {
@ -560,14 +510,13 @@ where
}
/// Processes an incoming request
async fn handle_request<T, D>(
async fn handle_request<T>(
api: Arc<T>,
ctx: DistantCtx<D>,
ctx: DistantCtx,
request: protocol::Request,
) -> protocol::Response
where
T: DistantApi<LocalData = D> + Send + Sync,
D: Send + Sync,
T: DistantApi + Send + Sync,
{
match request {
protocol::Request::Version {} => api

@ -11,9 +11,7 @@ use distant_net::common::{InmemoryTransport, OneshotListener};
use distant_net::server::{Server, ServerRef};
/// Stands up an inmemory client and server using the given api.
async fn setup(
api: impl DistantApi<LocalData = ()> + Send + Sync + 'static,
) -> (DistantClient, Box<dyn ServerRef>) {
async fn setup(api: impl DistantApi + Send + Sync + 'static) -> (DistantClient, ServerRef) {
let (t1, t2) = InmemoryTransport::pair(100);
let server = Server::new()
@ -33,22 +31,17 @@ async fn setup(
}
mod single {
use super::*;
use test_log::test;
use super::*;
#[test(tokio::test)]
async fn should_support_single_request_returning_error() {
struct TestDistantApi;
#[async_trait]
impl DistantApi for TestDistantApi {
type LocalData = ();
async fn read_file(
&self,
_ctx: DistantCtx<Self::LocalData>,
_path: PathBuf,
) -> io::Result<Vec<u8>> {
async fn read_file(&self, _ctx: DistantCtx, _path: PathBuf) -> io::Result<Vec<u8>> {
Err(io::Error::new(io::ErrorKind::NotFound, "test error"))
}
}
@ -66,13 +59,7 @@ mod single {
#[async_trait]
impl DistantApi for TestDistantApi {
type LocalData = ();
async fn read_file(
&self,
_ctx: DistantCtx<Self::LocalData>,
_path: PathBuf,
) -> io::Result<Vec<u8>> {
async fn read_file(&self, _ctx: DistantCtx, _path: PathBuf) -> io::Result<Vec<u8>> {
Ok(b"hello world".to_vec())
}
}
@ -85,25 +72,21 @@ mod single {
}
mod batch_parallel {
use super::*;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use distant_net::common::Request;
use distant_protocol::{Msg, Request as RequestPayload};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use test_log::test;
use super::*;
#[test(tokio::test)]
async fn should_support_multiple_requests_running_in_parallel() {
struct TestDistantApi;
#[async_trait]
impl DistantApi for TestDistantApi {
type LocalData = ();
async fn read_file(
&self,
_ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
) -> io::Result<Vec<u8>> {
async fn read_file(&self, _ctx: DistantCtx, path: PathBuf) -> io::Result<Vec<u8>> {
if path.to_str().unwrap() == "slow" {
tokio::time::sleep(Duration::from_millis(500)).await;
}
@ -155,13 +138,7 @@ mod batch_parallel {
#[async_trait]
impl DistantApi for TestDistantApi {
type LocalData = ();
async fn read_file(
&self,
_ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
) -> io::Result<Vec<u8>> {
async fn read_file(&self, _ctx: DistantCtx, path: PathBuf) -> io::Result<Vec<u8>> {
if path.to_str().unwrap() == "fail" {
return Err(io::Error::new(io::ErrorKind::Other, "test error"));
}
@ -211,25 +188,21 @@ mod batch_parallel {
}
mod batch_sequence {
use super::*;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use distant_net::common::Request;
use distant_protocol::{Msg, Request as RequestPayload};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use test_log::test;
use super::*;
#[test(tokio::test)]
async fn should_support_multiple_requests_running_in_sequence() {
struct TestDistantApi;
#[async_trait]
impl DistantApi for TestDistantApi {
type LocalData = ();
async fn read_file(
&self,
_ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
) -> io::Result<Vec<u8>> {
async fn read_file(&self, _ctx: DistantCtx, path: PathBuf) -> io::Result<Vec<u8>> {
if path.to_str().unwrap() == "slow" {
tokio::time::sleep(Duration::from_millis(500)).await;
}
@ -284,13 +257,7 @@ mod batch_sequence {
#[async_trait]
impl DistantApi for TestDistantApi {
type LocalData = ();
async fn read_file(
&self,
_ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
) -> io::Result<Vec<u8>> {
async fn read_file(&self, _ctx: DistantCtx, path: PathBuf) -> io::Result<Vec<u8>> {
if path.to_str().unwrap() == "fail" {
return Err(io::Error::new(io::ErrorKind::Other, "test error"));
}

@ -39,13 +39,7 @@ impl Api {
#[async_trait]
impl DistantApi for Api {
type LocalData = ();
async fn read_file(
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
) -> io::Result<Vec<u8>> {
async fn read_file(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<Vec<u8>> {
debug!(
"[Conn {}] Reading bytes from file {:?}",
ctx.connection_id, path
@ -54,11 +48,7 @@ impl DistantApi for Api {
tokio::fs::read(path).await
}
async fn read_file_text(
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
) -> io::Result<String> {
async fn read_file_text(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<String> {
debug!(
"[Conn {}] Reading text from file {:?}",
ctx.connection_id, path
@ -67,12 +57,7 @@ impl DistantApi for Api {
tokio::fs::read_to_string(path).await
}
async fn write_file(
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
data: Vec<u8>,
) -> io::Result<()> {
async fn write_file(&self, ctx: DistantCtx, path: PathBuf, data: Vec<u8>) -> io::Result<()> {
debug!(
"[Conn {}] Writing bytes to file {:?}",
ctx.connection_id, path
@ -83,7 +68,7 @@ impl DistantApi for Api {
async fn write_file_text(
&self,
ctx: DistantCtx<Self::LocalData>,
ctx: DistantCtx,
path: PathBuf,
data: String,
) -> io::Result<()> {
@ -95,12 +80,7 @@ impl DistantApi for Api {
tokio::fs::write(path, data).await
}
async fn append_file(
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
data: Vec<u8>,
) -> io::Result<()> {
async fn append_file(&self, ctx: DistantCtx, path: PathBuf, data: Vec<u8>) -> io::Result<()> {
debug!(
"[Conn {}] Appending bytes to file {:?}",
ctx.connection_id, path
@ -116,7 +96,7 @@ impl DistantApi for Api {
async fn append_file_text(
&self,
ctx: DistantCtx<Self::LocalData>,
ctx: DistantCtx,
path: PathBuf,
data: String,
) -> io::Result<()> {
@ -135,7 +115,7 @@ impl DistantApi for Api {
async fn read_dir(
&self,
ctx: DistantCtx<Self::LocalData>,
ctx: DistantCtx,
path: PathBuf,
depth: usize,
absolute: bool,
@ -228,12 +208,7 @@ impl DistantApi for Api {
Ok((entries, errors))
}
async fn create_dir(
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
all: bool,
) -> io::Result<()> {
async fn create_dir(&self, ctx: DistantCtx, path: PathBuf, all: bool) -> io::Result<()> {
debug!(
"[Conn {}] Creating directory {:?} {{all: {}}}",
ctx.connection_id, path, all
@ -245,12 +220,7 @@ impl DistantApi for Api {
}
}
async fn remove(
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
force: bool,
) -> io::Result<()> {
async fn remove(&self, ctx: DistantCtx, path: PathBuf, force: bool) -> io::Result<()> {
debug!(
"[Conn {}] Removing {:?} {{force: {}}}",
ctx.connection_id, path, force
@ -267,12 +237,7 @@ impl DistantApi for Api {
}
}
async fn copy(
&self,
ctx: DistantCtx<Self::LocalData>,
src: PathBuf,
dst: PathBuf,
) -> io::Result<()> {
async fn copy(&self, ctx: DistantCtx, src: PathBuf, dst: PathBuf) -> io::Result<()> {
debug!(
"[Conn {}] Copying {:?} to {:?}",
ctx.connection_id, src, dst
@ -329,12 +294,7 @@ impl DistantApi for Api {
Ok(())
}
async fn rename(
&self,
ctx: DistantCtx<Self::LocalData>,
src: PathBuf,
dst: PathBuf,
) -> io::Result<()> {
async fn rename(&self, ctx: DistantCtx, src: PathBuf, dst: PathBuf) -> io::Result<()> {
debug!(
"[Conn {}] Renaming {:?} to {:?}",
ctx.connection_id, src, dst
@ -344,7 +304,7 @@ impl DistantApi for Api {
async fn watch(
&self,
ctx: DistantCtx<Self::LocalData>,
ctx: DistantCtx,
path: PathBuf,
recursive: bool,
only: Vec<ChangeKind>,
@ -372,7 +332,7 @@ impl DistantApi for Api {
Ok(())
}
async fn unwatch(&self, ctx: DistantCtx<Self::LocalData>, path: PathBuf) -> io::Result<()> {
async fn unwatch(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<()> {
debug!("[Conn {}] Unwatching {:?}", ctx.connection_id, path);
self.state
@ -382,7 +342,7 @@ impl DistantApi for Api {
Ok(())
}
async fn exists(&self, ctx: DistantCtx<Self::LocalData>, path: PathBuf) -> io::Result<bool> {
async fn exists(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<bool> {
debug!("[Conn {}] Checking if {:?} exists", ctx.connection_id, path);
// Following experimental `std::fs::try_exists`, which checks the error kind of the
@ -396,7 +356,7 @@ impl DistantApi for Api {
async fn metadata(
&self,
ctx: DistantCtx<Self::LocalData>,
ctx: DistantCtx,
path: PathBuf,
canonicalize: bool,
resolve_file_type: bool,
@ -469,7 +429,7 @@ impl DistantApi for Api {
async fn set_permissions(
&self,
_ctx: DistantCtx<Self::LocalData>,
_ctx: DistantCtx,
path: PathBuf,
permissions: Permissions,
options: SetPermissionsOptions,
@ -596,11 +556,7 @@ impl DistantApi for Api {
}
}
async fn search(
&self,
ctx: DistantCtx<Self::LocalData>,
query: SearchQuery,
) -> io::Result<SearchId> {
async fn search(&self, ctx: DistantCtx, query: SearchQuery) -> io::Result<SearchId> {
debug!(
"[Conn {}] Performing search via {query:?}",
ctx.connection_id,
@ -609,11 +565,7 @@ impl DistantApi for Api {
self.state.search.start(query, ctx.reply).await
}
async fn cancel_search(
&self,
ctx: DistantCtx<Self::LocalData>,
id: SearchId,
) -> io::Result<()> {
async fn cancel_search(&self, ctx: DistantCtx, id: SearchId) -> io::Result<()> {
debug!("[Conn {}] Cancelling search {id}", ctx.connection_id,);
self.state.search.cancel(id).await
@ -621,7 +573,7 @@ impl DistantApi for Api {
async fn proc_spawn(
&self,
ctx: DistantCtx<Self::LocalData>,
ctx: DistantCtx,
cmd: String,
environment: Environment,
current_dir: Option<PathBuf>,
@ -637,17 +589,12 @@ impl DistantApi for Api {
.await
}
async fn proc_kill(&self, ctx: DistantCtx<Self::LocalData>, id: ProcessId) -> io::Result<()> {
async fn proc_kill(&self, ctx: DistantCtx, id: ProcessId) -> io::Result<()> {
debug!("[Conn {}] Killing process {}", ctx.connection_id, id);
self.state.process.kill(id).await
}
async fn proc_stdin(
&self,
ctx: DistantCtx<Self::LocalData>,
id: ProcessId,
data: Vec<u8>,
) -> io::Result<()> {
async fn proc_stdin(&self, ctx: DistantCtx, id: ProcessId, data: Vec<u8>) -> io::Result<()> {
debug!(
"[Conn {}] Sending stdin to process {}",
ctx.connection_id, id
@ -657,7 +604,7 @@ impl DistantApi for Api {
async fn proc_resize_pty(
&self,
ctx: DistantCtx<Self::LocalData>,
ctx: DistantCtx,
id: ProcessId,
size: PtySize,
) -> io::Result<()> {
@ -668,7 +615,7 @@ impl DistantApi for Api {
self.state.process.resize_pty(id, size).await
}
async fn system_info(&self, ctx: DistantCtx<Self::LocalData>) -> io::Result<SystemInfo> {
async fn system_info(&self, ctx: DistantCtx) -> io::Result<SystemInfo> {
debug!("[Conn {}] Reading system information", ctx.connection_id);
Ok(SystemInfo {
family: env::consts::FAMILY.to_string(),
@ -685,7 +632,7 @@ impl DistantApi for Api {
})
}
async fn version(&self, ctx: DistantCtx<Self::LocalData>) -> io::Result<Version> {
async fn version(&self, ctx: DistantCtx) -> io::Result<Version> {
debug!("[Conn {}] Querying version", ctx.connection_id);
Ok(Version {
@ -698,11 +645,10 @@ impl DistantApi for Api {
#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::time::Duration;
use assert_fs::prelude::*;
use distant_core::net::server::{ConnectionCtx, Reply};
use distant_core::net::server::Reply;
use distant_core::protocol::Response;
use once_cell::sync::Lazy;
use predicates::prelude::*;
@ -773,7 +719,7 @@ mod tests {
const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(100);
async fn setup(buffer: usize) -> (Api, DistantCtx<()>, mpsc::Receiver<Response>) {
async fn setup(buffer: usize) -> (Api, DistantCtx, mpsc::Receiver<Response>) {
let api = Api::initialize(Config {
watch: WatchConfig {
debounce_timeout: DEBOUNCE_TIMEOUT,
@ -784,19 +730,10 @@ mod tests {
let (reply, rx) = make_reply(buffer);
let connection_id = rand::random();
DistantApi::on_accept(
&api,
ConnectionCtx {
connection_id,
local_data: &mut (),
},
)
.await
.unwrap();
DistantApi::on_connect(&api, connection_id).await.unwrap();
let ctx = DistantCtx {
connection_id,
reply,
local_data: Arc::new(()),
};
(api, ctx, rx)
}
@ -1683,7 +1620,6 @@ mod tests {
let ctx = DistantCtx {
connection_id: ctx_1.connection_id,
reply,
local_data: Arc::clone(&ctx_1.local_data),
};
(ctx, rx)
};
@ -2662,7 +2598,6 @@ mod tests {
let ctx = DistantCtx {
connection_id: ctx_1.connection_id,
reply,
local_data: Arc::clone(&ctx_1.local_data),
};
(ctx, rx)
};
@ -2723,7 +2658,6 @@ mod tests {
let ctx = DistantCtx {
connection_id: ctx_1.connection_id,
reply,
local_data: Arc::clone(&ctx_1.local_data),
};
(ctx, rx)
};

@ -9,10 +9,10 @@ mod config;
mod constants;
pub use api::Api;
pub use config::*;
use distant_core::{DistantApi, DistantApiServerHandler};
use distant_core::DistantApiServerHandler;
/// Implementation of [`DistantApiServerHandler`] using [`Api`].
pub type Handler = DistantApiServerHandler<Api, <Api as DistantApi>::LocalData>;
pub type Handler = DistantApiServerHandler<Api>;
/// Initializes a new [`Handler`].
pub fn new_handler(config: Config) -> std::io::Result<Handler> {

@ -216,9 +216,7 @@ impl UntypedClient {
// If we have flagged that a reconnect is needed, attempt to do so
if needs_reconnect {
info!("Client encountered issue, attempting to reconnect");
if log::log_enabled!(log::Level::Debug) {
debug!("Using strategy {reconnect_strategy:?}");
}
debug!("Using strategy {reconnect_strategy:?}");
match reconnect_strategy.reconnect(&mut connection).await {
Ok(()) => {
info!("Client successfully reconnected!");
@ -236,7 +234,7 @@ impl UntypedClient {
macro_rules! silence_needs_reconnect {
() => {{
debug!(
info!(
"Client exceeded {}s without server activity, so attempting to reconnect",
silence_duration.as_secs_f32(),
);
@ -260,7 +258,7 @@ impl UntypedClient {
let ready = tokio::select! {
// NOTE: This should NEVER return None as we never allow the channel to close.
cb = shutdown_rx.recv() => {
debug!("Client got shutdown signal, so exiting event loop");
info!("Client got shutdown signal, so exiting event loop");
let cb = cb.expect("Impossible: shutdown channel closed!");
let _ = cb.send(Ok(()));
watcher_tx.send_replace(ConnectionState::Disconnected);
@ -335,7 +333,7 @@ impl UntypedClient {
}
Ok(None) => {
debug!("Connection closed");
info!("Connection closed");
needs_reconnect = true;
watcher_tx.send_replace(ConnectionState::Reconnecting);
continue;

@ -3,13 +3,13 @@ mod request;
mod response;
mod value;
use std::io::Cursor;
pub use header::*;
pub use request::*;
pub use response::*;
pub use value::*;
use std::io::Cursor;
/// Represents a generic id type
pub type Id = String;
@ -257,9 +257,10 @@ mod tests {
use super::*;
mod read_str_bytes {
use super::*;
use test_log::test;
use super::*;
#[test]
fn should_fail_if_input_is_empty() {
let input = read_str_bytes(&[]).unwrap_err();
@ -282,9 +283,10 @@ mod tests {
}
mod read_key_eq {
use super::*;
use test_log::test;
use super::*;
#[test]
fn should_fail_if_input_is_empty() {
let input = read_key_eq(&[], "key").unwrap_err();
@ -338,9 +340,10 @@ mod tests {
}
mod read_header_bytes {
use super::*;
use test_log::test;
use super::*;
#[test]
fn should_fail_if_input_is_empty() {
let input = vec![];
@ -527,9 +530,10 @@ mod tests {
}
mod find_msgpack_byte_len {
use super::*;
use test_log::test;
use super::*;
#[test]
fn should_return_none_if_input_is_empty() {
let input = vec![];

@ -1,10 +1,12 @@
use crate::common::{utils, Value};
use std::collections::HashMap;
use std::ops::{Deref, DerefMut};
use std::{fmt, io};
use derive_more::IntoIterator;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::io;
use std::ops::{Deref, DerefMut};
use crate::common::{utils, Value};
/// Generates a new [`Header`] of key/value pairs based on literals.
///
@ -90,3 +92,18 @@ impl DerefMut for Header {
&mut self.0
}
}
impl fmt::Display for Header {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{{")?;
for (key, value) in self.0.iter() {
let value = serde_json::to_string(value).unwrap_or_else(|_| String::from("--"));
write!(f, "\"{key}\" = {value}")?;
}
write!(f, "}}")?;
Ok(())
}
}

@ -1,10 +1,12 @@
use crate::common::utils;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use std::io;
use std::ops::{Deref, DerefMut};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use crate::common::utils;
/// Generic value type for data passed through header.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(transparent)]

@ -12,7 +12,7 @@ use crate::manager::{
ConnectionInfo, ConnectionList, ManagerAuthenticationId, ManagerCapabilities, ManagerChannelId,
ManagerRequest, ManagerResponse,
};
use crate::server::{Server, ServerCtx, ServerHandler};
use crate::server::{RequestCtx, Server, ServerHandler};
mod authentication;
pub use authentication::*;
@ -31,6 +31,10 @@ pub struct ManagerServer {
/// Configuration settings for the server
config: Config,
/// Holds on to open channels feeding data back from a server to some connected client,
/// enabling us to cancel the tasks on demand
channels: RwLock<HashMap<ManagerChannelId, ManagerChannel>>,
/// Mapping of connection id -> connection
connections: RwLock<HashMap<ConnectionId, ManagerConnection>>,
@ -46,6 +50,7 @@ impl ManagerServer {
pub fn new(config: Config) -> Server<Self> {
Server::new().handler(Self {
config,
channels: RwLock::new(HashMap::new()),
connections: RwLock::new(HashMap::new()),
registry: Arc::new(RwLock::new(HashMap::new())),
})
@ -177,104 +182,120 @@ impl ManagerServer {
}
}
#[derive(Default)]
pub struct DistantManagerServerConnection {
/// Holds on to open channels feeding data back from a server to some connected client,
/// enabling us to cancel the tasks on demand
channels: RwLock<HashMap<ManagerChannelId, ManagerChannel>>,
}
#[async_trait]
impl ServerHandler for ManagerServer {
type LocalData = DistantManagerServerConnection;
type Request = ManagerRequest;
type Response = ManagerResponse;
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) {
let ServerCtx {
async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>) {
debug!("manager::on_request({ctx:?})");
let RequestCtx {
connection_id,
request,
reply,
local_data,
} = ctx;
let response = match request.payload {
ManagerRequest::Capabilities {} => match self.capabilities().await {
Ok(supported) => ManagerResponse::Capabilities { supported },
Err(x) => ManagerResponse::from(x),
},
ManagerRequest::Capabilities {} => {
debug!("Looking up capabilities");
match self.capabilities().await {
Ok(supported) => ManagerResponse::Capabilities { supported },
Err(x) => ManagerResponse::from(x),
}
}
ManagerRequest::Launch {
destination,
options,
} => match self
.launch(
*destination,
options,
ManagerAuthenticator {
reply: reply.clone(),
registry: Arc::clone(&self.registry),
},
)
.await
{
Ok(destination) => ManagerResponse::Launched { destination },
Err(x) => ManagerResponse::from(x),
},
} => {
info!("Launching {destination} with {options}");
match self
.launch(
*destination,
options,
ManagerAuthenticator {
reply: reply.clone(),
registry: Arc::clone(&self.registry),
},
)
.await
{
Ok(destination) => ManagerResponse::Launched { destination },
Err(x) => ManagerResponse::from(x),
}
}
ManagerRequest::Connect {
destination,
options,
} => match self
.connect(
*destination,
options,
ManagerAuthenticator {
reply: reply.clone(),
registry: Arc::clone(&self.registry),
},
)
.await
{
Ok(id) => ManagerResponse::Connected { id },
Err(x) => ManagerResponse::from(x),
},
} => {
info!("Connecting to {destination} with {options}");
match self
.connect(
*destination,
options,
ManagerAuthenticator {
reply: reply.clone(),
registry: Arc::clone(&self.registry),
},
)
.await
{
Ok(id) => ManagerResponse::Connected { id },
Err(x) => ManagerResponse::from(x),
}
}
ManagerRequest::Authenticate { id, msg } => {
trace!("Retrieving authentication callback registry");
match self.registry.write().await.remove(&id) {
Some(cb) => match cb.send(msg) {
Ok(_) => return,
Err(_) => ManagerResponse::Error {
description: "Unable to forward authentication callback".to_string(),
},
},
Some(cb) => {
trace!("Sending {msg:?} through authentication callback");
match cb.send(msg) {
Ok(_) => return,
Err(_) => ManagerResponse::Error {
description: "Unable to forward authentication callback"
.to_string(),
},
}
}
None => ManagerResponse::from(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid authentication id",
)),
}
}
ManagerRequest::OpenChannel { id } => match self.connections.read().await.get(&id) {
Some(connection) => match connection.open_channel(reply.clone()) {
Ok(channel) => {
debug!("[Conn {id}] Channel {} has been opened", channel.id());
let id = channel.id();
local_data.channels.write().await.insert(id, channel);
ManagerResponse::ChannelOpened { id }
ManagerRequest::OpenChannel { id } => {
debug!("Attempting to retrieve connection {id}");
match self.connections.read().await.get(&id) {
Some(connection) => {
debug!("Opening channel through connection {id}");
match connection.open_channel(reply.clone()) {
Ok(channel) => {
info!("[Conn {id}] Channel {} has been opened", channel.id());
let id = channel.id();
self.channels.write().await.insert(id, channel);
ManagerResponse::ChannelOpened { id }
}
Err(x) => ManagerResponse::from(x),
}
}
Err(x) => ManagerResponse::from(x),
},
None => ManagerResponse::from(io::Error::new(
io::ErrorKind::NotConnected,
"Connection does not exist",
)),
},
None => ManagerResponse::from(io::Error::new(
io::ErrorKind::NotConnected,
"Connection does not exist",
)),
}
}
ManagerRequest::Channel { id, request } => {
match local_data.channels.read().await.get(&id) {
debug!("Attempting to retrieve channel {id}");
match self.channels.read().await.get(&id) {
// TODO: For now, we are NOT sending back a response to acknowledge
// a successful channel send. We could do this in order for
// the client to listen for a complete send, but is it worth it?
Some(channel) => match channel.send(request) {
Ok(_) => return,
Err(x) => ManagerResponse::from(x),
},
Some(channel) => {
debug!("Sending {request:?} through channel {id}");
match channel.send(request) {
Ok(_) => return,
Err(x) => ManagerResponse::from(x),
}
}
None => ManagerResponse::from(io::Error::new(
io::ErrorKind::NotConnected,
"Channel is not open or does not exist",
@ -282,32 +303,54 @@ impl ServerHandler for ManagerServer {
}
}
ManagerRequest::CloseChannel { id } => {
match local_data.channels.write().await.remove(&id) {
Some(channel) => match channel.close() {
Ok(_) => {
debug!("Channel {id} has been closed");
ManagerResponse::ChannelClosed { id }
debug!("Attempting to remove channel {id}");
match self.channels.write().await.remove(&id) {
Some(channel) => {
debug!("Removed channel {}", channel.id());
match channel.close() {
Ok(_) => {
info!("Channel {id} has been closed");
ManagerResponse::ChannelClosed { id }
}
Err(x) => ManagerResponse::from(x),
}
Err(x) => ManagerResponse::from(x),
},
}
None => ManagerResponse::from(io::Error::new(
io::ErrorKind::NotConnected,
"Channel is not open or does not exist",
)),
}
}
ManagerRequest::Info { id } => match self.info(id).await {
Ok(info) => ManagerResponse::Info(info),
Err(x) => ManagerResponse::from(x),
},
ManagerRequest::List => match self.list().await {
Ok(list) => ManagerResponse::List(list),
Err(x) => ManagerResponse::from(x),
},
ManagerRequest::Kill { id } => match self.kill(id).await {
Ok(()) => ManagerResponse::Killed,
Err(x) => ManagerResponse::from(x),
},
ManagerRequest::Info { id } => {
debug!("Attempting to retrieve information for connection {id}");
match self.info(id).await {
Ok(info) => {
info!("Retrieved information for connection {id}");
ManagerResponse::Info(info)
}
Err(x) => ManagerResponse::from(x),
}
}
ManagerRequest::List => {
debug!("Attempting to retrieve the list of connections");
match self.list().await {
Ok(list) => {
info!("Retrieved list of connections");
ManagerResponse::List(list)
}
Err(x) => ManagerResponse::from(x),
}
}
ManagerRequest::Kill { id } => {
debug!("Attempting to kill connection {id}");
match self.kill(id).await {
Ok(()) => {
info!("Killed connection {id}");
ManagerResponse::Killed
}
Err(x) => ManagerResponse::from(x),
}
}
};
if let Err(x) = reply.send(response).await {
@ -356,6 +399,7 @@ mod tests {
let server = ManagerServer {
config,
channels: RwLock::new(HashMap::new()),
connections: RwLock::new(HashMap::new()),
registry,
};

@ -1,5 +1,5 @@
use std::collections::HashMap;
use std::io;
use std::{fmt, io};
use log::*;
use tokio::sync::mpsc;
@ -135,6 +135,17 @@ enum Action {
},
}
impl fmt::Debug for Action {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Register { id, .. } => write!(f, "Action::Register {{ id: {id}, .. }}"),
Self::Unregister { id } => write!(f, "Action::Unregister {{ id: {id} }}"),
Self::Read { .. } => write!(f, "Action::Read {{ .. }}"),
Self::Write { id, .. } => write!(f, "Action::Write {{ id: {id}, .. }}"),
}
}
}
/// Internal task to process outgoing [`UntypedRequest`]s.
async fn request_task(
id: ConnectionId,
@ -142,10 +153,13 @@ async fn request_task(
mut rx: mpsc::UnboundedReceiver<UntypedRequest<'static>>,
) {
while let Some(req) = rx.recv().await {
trace!("[Conn {id}] Firing off request {}", req.id);
if let Err(x) = client.fire(req).await {
error!("[Conn {id}] Failed to send request: {x}");
}
}
trace!("[Conn {id}] Manager request task closed");
}
/// Internal task to process incoming [`UntypedResponse`]s.
@ -155,10 +169,17 @@ async fn response_task(
tx: mpsc::UnboundedSender<Action>,
) {
while let Some(res) = mailbox.next().await {
trace!(
"[Conn {id}] Receiving response {} to request {}",
res.id,
res.origin_id
);
if let Err(x) = tx.send(Action::Read { res }) {
error!("[Conn {id}] Failed to forward received response: {x}");
}
}
trace!("[Conn {id}] Manager response task closed");
}
/// Internal task to process [`Action`] items.
@ -174,6 +195,8 @@ async fn action_task(
let mut registered = HashMap::new();
while let Some(action) = rx.recv().await {
trace!("[Conn {id}] {action:?}");
match action {
Action::Register { id, reply } => {
registered.insert(id, reply);
@ -201,9 +224,20 @@ async fn action_task(
id: channel_id,
response: res,
};
if let Err(x) = reply.send(response).await {
error!("[Conn {id}] {x}");
}
// TODO: This seems to get stuck at times with some change recently,
// so we kick this off in a new task instead. The better solution
// is to switch most of our mpsc usage to be unbounded so we
// don't need an async call. The only bounded ones should be those
// externally facing to the API user, if even that.
//
// https://github.com/chipsenkbeil/distant/issues/205
let reply = reply.clone();
tokio::spawn(async move {
if let Err(x) = reply.send(response).await {
error!("[Conn {id}] {x}");
}
});
}
}
Action::Write { id, mut req } => {
@ -217,4 +251,6 @@ async fn action_task(
}
}
}
trace!("[Conn {id}] Manager action task closed");
}

@ -9,7 +9,7 @@ use serde::de::DeserializeOwned;
use serde::Serialize;
use tokio::sync::{broadcast, RwLock};
use crate::common::{Listener, Response, Transport};
use crate::common::{ConnectionId, Listener, Response, Transport};
mod builder;
pub use builder::*;
@ -56,23 +56,21 @@ pub trait ServerHandler: Send {
/// Type of data sent back by the server
type Response;
/// Type of data to store locally tied to the specific connection
type LocalData: Send;
/// Invoked upon a new connection becoming established.
///
/// ### Note
///
/// This can be useful in performing some additional initialization on the connection's local
/// data prior to it being used anywhere else.
#[allow(unused_variables)]
async fn on_accept(&self, ctx: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> {
async fn on_connect(&self, id: ConnectionId) -> io::Result<()> {
Ok(())
}
/// Invoked upon an existing connection getting dropped.
#[allow(unused_variables)]
async fn on_disconnect(&self, id: ConnectionId) -> io::Result<()> {
Ok(())
}
/// Invoked upon receiving a request from a client. The server should process this
/// request, which can be found in `ctx`, and send one or more replies in response.
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>);
async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>);
}
impl Server<()> {
@ -144,11 +142,10 @@ where
T: ServerHandler + Sync + 'static,
T::Request: DeserializeOwned + Send + Sync + 'static,
T::Response: Serialize + Send + 'static,
T::LocalData: Default + Send + Sync + 'static,
{
/// Consumes the server, starting a task to process connections from the `listener` and
/// returning a [`ServerRef`] that can be used to control the active server instance.
pub fn start<L>(self, listener: L) -> io::Result<Box<dyn ServerRef>>
pub fn start<L>(self, listener: L) -> io::Result<ServerRef>
where
L: Listener + 'static,
L::Output: Transport + 'static,
@ -157,7 +154,7 @@ where
let (tx, rx) = broadcast::channel(1);
let task = tokio::spawn(self.task(Arc::clone(&state), listener, tx.clone(), rx));
Ok(Box::new(GenericServerRef { shutdown: tx, task }))
Ok(ServerRef { shutdown: tx, task })
}
/// Internal task that is run to receive connections and spawn connection tasks
@ -226,6 +223,9 @@ where
.verifier(Arc::downgrade(&verifier))
.spawn(),
);
// Clean up current tasks being tracked
connection_tasks.retain(|task| !task.is_finished());
}
// Once we stop listening, we still want to wait until all connections have terminated
@ -257,15 +257,10 @@ mod tests {
#[async_trait]
impl ServerHandler for TestServerHandler {
type LocalData = ();
type Request = u16;
type Response = String;
async fn on_accept(&self, _: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> {
Ok(())
}
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) {
async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>) {
// Always send back "hello"
ctx.reply.send("hello".to_string()).await.unwrap();
}

@ -42,7 +42,6 @@ where
T: ServerHandler + Sync + 'static,
T::Request: DeserializeOwned + Send + Sync + 'static,
T::Response: Serialize + Send + 'static,
T::LocalData: Default + Send + Sync + 'static,
{
pub async fn start<P>(self, addr: IpAddr, port: P) -> io::Result<TcpServerRef>
where
@ -66,17 +65,16 @@ mod tests {
use super::*;
use crate::client::Client;
use crate::common::Request;
use crate::server::ServerCtx;
use crate::server::RequestCtx;
pub struct TestServerHandler;
#[async_trait]
impl ServerHandler for TestServerHandler {
type LocalData = ();
type Request = String;
type Response = String;
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) {
async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>) {
// Echo back what we received
ctx.reply
.send(ctx.request.payload.to_string())

@ -42,7 +42,6 @@ where
T: ServerHandler + Sync + 'static,
T::Request: DeserializeOwned + Send + Sync + 'static,
T::Response: Serialize + Send + 'static,
T::LocalData: Default + Send + Sync + 'static,
{
pub async fn start<P>(self, path: P) -> io::Result<UnixSocketServerRef>
where
@ -66,17 +65,16 @@ mod tests {
use super::*;
use crate::client::Client;
use crate::common::Request;
use crate::server::ServerCtx;
use crate::server::RequestCtx;
pub struct TestServerHandler;
#[async_trait]
impl ServerHandler for TestServerHandler {
type LocalData = ();
type Request = String;
type Response = String;
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) {
async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>) {
// Echo back what we received
ctx.reply
.send(ctx.request.payload.to_string())

@ -42,7 +42,6 @@ where
T: ServerHandler + Sync + 'static,
T::Request: DeserializeOwned + Send + Sync + 'static,
T::Response: Serialize + Send + 'static,
T::LocalData: Default + Send + Sync + 'static,
{
/// Start a new server at the specified address using the given codec
pub async fn start<A>(self, addr: A) -> io::Result<WindowsPipeServerRef>
@ -77,17 +76,16 @@ mod tests {
use super::*;
use crate::client::Client;
use crate::common::Request;
use crate::server::ServerCtx;
use crate::server::RequestCtx;
pub struct TestServerHandler;
#[async_trait]
impl ServerHandler for TestServerHandler {
type LocalData = ();
type Request = String;
type Response = String;
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) {
async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>) {
// Echo back what we received
ctx.reply
.send(ctx.request.payload.to_string())

@ -12,10 +12,7 @@ use serde::Serialize;
use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
use tokio::task::JoinHandle;
use super::{
ConnectionCtx, ConnectionState, ServerCtx, ServerHandler, ServerReply, ServerState,
ShutdownTimer,
};
use super::{ConnectionState, RequestCtx, ServerHandler, ServerReply, ServerState, ShutdownTimer};
use crate::common::{
Backup, Connection, Frame, Interest, Keychain, Response, Transport, UntypedRequest,
};
@ -226,7 +223,6 @@ where
H: ServerHandler + Sync + 'static,
H::Request: DeserializeOwned + Send + Sync + 'static,
H::Response: Serialize + Send + 'static,
H::LocalData: Default + Send + Sync + 'static,
T: Transport + 'static,
{
pub fn spawn(self) -> ConnectionTask {
@ -429,16 +425,11 @@ where
let id = connection.id();
// Create local data for the connection and then process it
debug!("[Conn {id}] Officially accepting connection");
let mut local_data = H::LocalData::default();
if let Err(x) = await_or_shutdown!(handler.on_accept(ConnectionCtx {
connection_id: id,
local_data: &mut local_data
})) {
info!("[Conn {id}] Connection established");
if let Err(x) = await_or_shutdown!(handler.on_connect(id)) {
terminate_connection!(@fatal "[Conn {id}] Accepting connection failed: {x}");
}
let local_data = Arc::new(local_data);
let mut last_heartbeat = Instant::now();
// Restore our connection's channels if we have them, otherwise make new ones
@ -483,15 +474,22 @@ where
Ok(Some(frame)) => match UntypedRequest::from_slice(frame.as_item()) {
Ok(request) => match request.to_typed_request() {
Ok(request) => {
if log::log_enabled!(Level::Debug) {
let debug_header = if !request.header.is_empty() {
format!(" | header {}", request.header)
} else {
String::new()
};
debug!("[Conn {id}] New request {}{debug_header}", request.id);
}
let origin_id = request.id.clone();
let ctx = ServerCtx {
let ctx = RequestCtx {
connection_id: id,
request,
reply: ServerReply {
origin_id,
tx: tx.clone(),
},
local_data: Arc::clone(&local_data),
};
// Spawn a new task to run the request handler so we don't block
@ -500,8 +498,8 @@ where
tokio::spawn(async move { handler.on_request(ctx).await });
}
Err(x) => {
if log::log_enabled!(Level::Trace) {
trace!(
if log::log_enabled!(Level::Debug) {
error!(
"[Conn {id}] Failed receiving {}",
String::from_utf8_lossy(&request.payload),
);
@ -600,21 +598,16 @@ mod tests {
use crate::common::{
HeapSecretKey, InmemoryTransport, Ready, Reconnectable, Request, Response,
};
use crate::server::Shutdown;
use crate::server::{ConnectionId, Shutdown};
struct TestServerHandler;
#[async_trait]
impl ServerHandler for TestServerHandler {
type LocalData = ();
type Request = u16;
type Response = String;
async fn on_accept(&self, _: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> {
Ok(())
}
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) {
async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>) {
// Always send back "hello"
ctx.reply.send("hello".to_string()).await.unwrap();
}
@ -735,18 +728,14 @@ mod tests {
#[async_trait]
impl ServerHandler for BadAcceptServerHandler {
type LocalData = ();
type Request = u16;
type Response = String;
async fn on_accept(&self, _: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> {
Err(io::Error::new(io::ErrorKind::Other, "bad accept"))
async fn on_connect(&self, _: ConnectionId) -> io::Result<()> {
Err(io::Error::new(io::ErrorKind::Other, "bad connect"))
}
async fn on_request(
&self,
_: ServerCtx<Self::Request, Self::Response, Self::LocalData>,
) {
async fn on_request(&self, _: RequestCtx<Self::Request, Self::Response>) {
unreachable!();
}
}
@ -1027,20 +1016,16 @@ mod tests {
#[async_trait]
impl ServerHandler for HangingAcceptServerHandler {
type LocalData = ();
type Request = ();
type Response = ();
async fn on_accept(&self, _: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> {
async fn on_connect(&self, _: ConnectionId) -> io::Result<()> {
// Wait "forever" so we can ensure that we fail at this step
tokio::time::sleep(Duration::MAX).await;
Err(io::Error::new(io::ErrorKind::Other, "bad accept"))
Err(io::Error::new(io::ErrorKind::Other, "bad connect"))
}
async fn on_request(
&self,
_: ServerCtx<Self::Request, Self::Response, Self::LocalData>,
) {
async fn on_request(&self, _: RequestCtx<Self::Request, Self::Response>) {
unreachable!();
}
}
@ -1083,19 +1068,15 @@ mod tests {
#[async_trait]
impl ServerHandler for AcceptServerHandler {
type LocalData = ();
type Request = ();
type Response = ();
async fn on_accept(&self, _: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> {
async fn on_connect(&self, _: ConnectionId) -> io::Result<()> {
self.tx.send(()).await.unwrap();
Ok(())
}
async fn on_request(
&self,
_: ServerCtx<Self::Request, Self::Response, Self::LocalData>,
) {
async fn on_request(&self, _: RequestCtx<Self::Request, Self::Response>) {
unreachable!();
}
}

@ -1,28 +1,29 @@
use std::sync::Arc;
use std::fmt;
use super::ServerReply;
use crate::common::{ConnectionId, Request};
/// Represents contextual information for working with an inbound request
pub struct ServerCtx<T, U, D> {
/// Unique identifer associated with the connection that sent the request
/// Represents contextual information for working with an inbound request.
pub struct RequestCtx<T, U> {
/// Unique identifer associated with the connection that sent the request.
pub connection_id: ConnectionId,
/// The request being handled
/// The request being handled.
pub request: Request<T>,
/// Used to send replies back to be sent out by the server
/// Used to send replies back to be sent out by the server.
pub reply: ServerReply<U>,
/// Reference to the connection's local data
pub local_data: Arc<D>,
}
/// Represents contextual information for working with an inbound connection
pub struct ConnectionCtx<'a, D> {
/// Unique identifer associated with the connection
pub connection_id: ConnectionId,
/// Reference to the connection's local data
pub local_data: &'a mut D,
impl<T, U> fmt::Debug for RequestCtx<T, U>
where
T: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("RequestCtx")
.field("connection_id", &self.connection_id)
.field("request", &self.request)
.field("reply", &"...")
.finish()
}
}

@ -1,94 +1,27 @@
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::sync::broadcast;
use tokio::task::{JoinError, JoinHandle};
use crate::common::AsAny;
/// Interface to engage with a server instance.
pub trait ServerRef: AsAny + Send {
/// Returns true if the server is no longer running.
fn is_finished(&self) -> bool;
/// Sends a shutdown signal to the server.
fn shutdown(&self);
fn wait(self) -> Pin<Box<dyn Future<Output = io::Result<()>>>>
where
Self: Sized + 'static,
{
Box::pin(async {
let task = tokio::spawn(async move {
while !self.is_finished() {
tokio::time::sleep(Duration::from_millis(100)).await;
}
});
task.await
.map_err(|x| io::Error::new(io::ErrorKind::Other, x))
})
}
}
impl dyn ServerRef {
/// Attempts to convert this ref into a concrete ref by downcasting
pub fn as_server_ref<R: ServerRef>(&self) -> Option<&R> {
self.as_any().downcast_ref::<R>()
}
/// Attempts to convert this mutable ref into a concrete mutable ref by downcasting
pub fn as_mut_server_ref<R: ServerRef>(&mut self) -> Option<&mut R> {
self.as_mut_any().downcast_mut::<R>()
}
/// Attempts to convert this into a concrete, boxed ref by downcasting
pub fn into_boxed_server_ref<R: ServerRef>(
self: Box<Self>,
) -> Result<Box<R>, Box<dyn std::any::Any>> {
self.into_any().downcast::<R>()
}
/// Waits for the server to complete by continuously polling the finished state.
pub async fn polling_wait(&self) -> io::Result<()> {
while !self.is_finished() {
tokio::time::sleep(Duration::from_millis(100)).await;
}
Ok(())
}
}
/// Represents a generic reference to a server
pub struct GenericServerRef {
/// Represents a reference to a server
pub struct ServerRef {
pub(crate) shutdown: broadcast::Sender<()>,
pub(crate) task: JoinHandle<()>,
}
/// Runtime-specific implementation of [`ServerRef`] for a [`tokio::task::JoinHandle`]
impl ServerRef for GenericServerRef {
fn is_finished(&self) -> bool {
impl ServerRef {
pub fn is_finished(&self) -> bool {
self.task.is_finished()
}
fn shutdown(&self) {
pub fn shutdown(&self) {
let _ = self.shutdown.send(());
}
fn wait(self) -> Pin<Box<dyn Future<Output = io::Result<()>>>>
where
Self: Sized + 'static,
{
Box::pin(async {
self.task
.await
.map_err(|x| io::Error::new(io::ErrorKind::Other, x))
})
}
}
impl Future for GenericServerRef {
impl Future for ServerRef {
type Output = Result<(), JoinError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {

@ -1,36 +1,59 @@
use std::future::Future;
use std::net::IpAddr;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::task::JoinError;
use super::ServerRef;
/// Reference to a TCP server instance
/// Reference to a TCP server instance.
pub struct TcpServerRef {
pub(crate) addr: IpAddr,
pub(crate) port: u16,
pub(crate) inner: Box<dyn ServerRef>,
pub(crate) inner: ServerRef,
}
impl TcpServerRef {
pub fn new(addr: IpAddr, port: u16, inner: Box<dyn ServerRef>) -> Self {
pub fn new(addr: IpAddr, port: u16, inner: ServerRef) -> Self {
Self { addr, port, inner }
}
/// Returns the IP address that the listener is bound to
/// Returns the IP address that the listener is bound to.
pub fn ip_addr(&self) -> IpAddr {
self.addr
}
/// Returns the port that the listener is bound to
/// Returns the port that the listener is bound to.
pub fn port(&self) -> u16 {
self.port
}
/// Consumes ref, returning inner ref.
pub fn into_inner(self) -> ServerRef {
self.inner
}
}
impl ServerRef for TcpServerRef {
fn is_finished(&self) -> bool {
self.inner.is_finished()
impl Future for TcpServerRef {
type Output = Result<(), JoinError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.inner.task).poll(cx)
}
}
impl Deref for TcpServerRef {
type Target = ServerRef;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
fn shutdown(&self) {
self.inner.shutdown();
impl DerefMut for TcpServerRef {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}

@ -1,35 +1,53 @@
use std::future::Future;
use std::ops::{Deref, DerefMut};
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::task::JoinError;
use super::ServerRef;
/// Reference to a unix socket server instance
/// Reference to a unix socket server instance.
pub struct UnixSocketServerRef {
pub(crate) path: PathBuf,
pub(crate) inner: Box<dyn ServerRef>,
pub(crate) inner: ServerRef,
}
impl UnixSocketServerRef {
pub fn new(path: PathBuf, inner: Box<dyn ServerRef>) -> Self {
pub fn new(path: PathBuf, inner: ServerRef) -> Self {
Self { path, inner }
}
/// Returns the path to the socket
/// Returns the path to the socket.
pub fn path(&self) -> &Path {
&self.path
}
/// Consumes ref, returning inner ref
pub fn into_inner(self) -> Box<dyn ServerRef> {
/// Consumes ref, returning inner ref.
pub fn into_inner(self) -> ServerRef {
self.inner
}
}
impl ServerRef for UnixSocketServerRef {
fn is_finished(&self) -> bool {
self.inner.is_finished()
impl Future for UnixSocketServerRef {
type Output = Result<(), JoinError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.inner.task).poll(cx)
}
}
impl Deref for UnixSocketServerRef {
type Target = ServerRef;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
fn shutdown(&self) {
self.inner.shutdown();
impl DerefMut for UnixSocketServerRef {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}

@ -1,35 +1,53 @@
use std::ffi::{OsStr, OsString};
use std::future::Future;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::task::JoinError;
use super::ServerRef;
/// Reference to a unix socket server instance
/// Reference to a windows pipe server instance.
pub struct WindowsPipeServerRef {
pub(crate) addr: OsString,
pub(crate) inner: Box<dyn ServerRef>,
pub(crate) inner: ServerRef,
}
impl WindowsPipeServerRef {
pub fn new(addr: OsString, inner: Box<dyn ServerRef>) -> Self {
pub fn new(addr: OsString, inner: ServerRef) -> Self {
Self { addr, inner }
}
/// Returns the addr that the listener is bound to
/// Returns the addr that the listener is bound to.
pub fn addr(&self) -> &OsStr {
&self.addr
}
/// Consumes ref, returning inner ref
pub fn into_inner(self) -> Box<dyn ServerRef> {
/// Consumes ref, returning inner ref.
pub fn into_inner(self) -> ServerRef {
self.inner
}
}
impl ServerRef for WindowsPipeServerRef {
fn is_finished(&self) -> bool {
self.inner.is_finished()
impl Future for WindowsPipeServerRef {
type Output = Result<(), JoinError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.inner.task).poll(cx)
}
}
impl Deref for WindowsPipeServerRef {
type Target = ServerRef;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
fn shutdown(&self) {
self.inner.shutdown();
impl DerefMut for WindowsPipeServerRef {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}

@ -6,7 +6,7 @@ use distant_net::boxed_connect_handler;
use distant_net::client::Client;
use distant_net::common::{Destination, InmemoryTransport, Map, OneshotListener};
use distant_net::manager::{Config, ManagerClient, ManagerServer};
use distant_net::server::{Server, ServerCtx, ServerHandler};
use distant_net::server::{RequestCtx, Server, ServerHandler};
use log::*;
use test_log::test;
@ -14,11 +14,10 @@ struct TestServerHandler;
#[async_trait]
impl ServerHandler for TestServerHandler {
type LocalData = ();
type Request = String;
type Response = String;
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) {
async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>) {
ctx.reply
.send(format!("echo {}", ctx.request.payload))
.await
@ -37,7 +36,7 @@ async fn should_be_able_to_establish_a_single_connection_and_communicate_with_a_
let (t1, t2) = InmemoryTransport::pair(100);
// Spawn a server on one end and connect to it on the other
let _ = Server::new()
let _server = Server::new()
.handler(TestServerHandler)
.verifier(Verifier::none())
.start(OneshotListener::from_value(t2))?;

@ -2,7 +2,7 @@ use async_trait::async_trait;
use distant_auth::{DummyAuthHandler, Verifier};
use distant_net::client::Client;
use distant_net::common::{InmemoryTransport, OneshotListener};
use distant_net::server::{Server, ServerCtx, ServerHandler};
use distant_net::server::{RequestCtx, Server, ServerHandler};
use log::*;
use test_log::test;
@ -10,11 +10,10 @@ struct TestServerHandler;
#[async_trait]
impl ServerHandler for TestServerHandler {
type LocalData = ();
type Request = (u8, String);
type Response = String;
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) {
async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>) {
let (cnt, msg) = ctx.request.payload;
for i in 0..cnt {
@ -30,7 +29,7 @@ impl ServerHandler for TestServerHandler {
async fn should_be_able_to_send_and_receive_typed_payloads_between_client_and_server() {
let (t1, t2) = InmemoryTransport::pair(100);
let _ = Server::new()
let _server = Server::new()
.handler(TestServerHandler)
.verifier(Verifier::none())
.start(OneshotListener::from_value(t2))

@ -2,7 +2,7 @@ use async_trait::async_trait;
use distant_auth::{DummyAuthHandler, Verifier};
use distant_net::client::Client;
use distant_net::common::{InmemoryTransport, OneshotListener, Request};
use distant_net::server::{Server, ServerCtx, ServerHandler};
use distant_net::server::{RequestCtx, Server, ServerHandler};
use log::*;
use test_log::test;
@ -10,11 +10,10 @@ struct TestServerHandler;
#[async_trait]
impl ServerHandler for TestServerHandler {
type LocalData = ();
type Request = (u8, String);
type Response = String;
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) {
async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>) {
let (cnt, msg) = ctx.request.payload;
for i in 0..cnt {
@ -30,7 +29,7 @@ impl ServerHandler for TestServerHandler {
async fn should_be_able_to_send_and_receive_untyped_payloads_between_client_and_server() {
let (t1, t2) = InmemoryTransport::pair(100);
let _ = Server::new()
let _server = Server::new()
.handler(TestServerHandler)
.verifier(Verifier::none())
.start(OneshotListener::from_value(t2))

@ -1,4 +1,4 @@
use std::collections::{HashMap, HashSet, VecDeque};
use std::collections::{HashMap, VecDeque};
use std::io;
use std::path::PathBuf;
use std::sync::{Arc, Weak};
@ -7,7 +7,6 @@ use std::time::Duration;
use async_compat::CompatExt;
use async_once_cell::OnceCell;
use async_trait::async_trait;
use distant_core::net::server::ConnectionCtx;
use distant_core::protocol::{
Capabilities, CapabilityKind, DirEntry, Environment, FileType, Metadata, Permissions,
ProcessId, PtySize, SetPermissionsOptions, SystemInfo, UnixMetadata, Version, PROTOCOL_VERSION,
@ -25,16 +24,6 @@ use crate::utils::{self, to_other_error};
/// Time after copy completes to wait for stdout/stderr to close
const COPY_COMPLETE_TIMEOUT: Duration = Duration::from_secs(1);
#[derive(Default)]
pub struct ConnectionState {
/// List of process ids that will be killed when the connection terminates
processes: Arc<RwLock<HashSet<ProcessId>>>,
/// Internal reference to global process list for removals
/// NOTE: Initialized during `on_accept` of [`DistantApi`]
global_processes: Weak<RwLock<HashMap<ProcessId, Process>>>,
}
struct Process {
stdin_tx: mpsc::Sender<Vec<u8>>,
kill_tx: mpsc::Sender<()>,
@ -72,18 +61,7 @@ impl SshDistantApi {
#[async_trait]
impl DistantApi for SshDistantApi {
type LocalData = ConnectionState;
async fn on_accept(&self, ctx: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> {
ctx.local_data.global_processes = Arc::downgrade(&self.processes);
Ok(())
}
async fn read_file(
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
) -> io::Result<Vec<u8>> {
async fn read_file(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<Vec<u8>> {
debug!(
"[Conn {}] Reading bytes from file {:?}",
ctx.connection_id, path
@ -103,11 +81,7 @@ impl DistantApi for SshDistantApi {
Ok(contents.into_bytes())
}
async fn read_file_text(
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
) -> io::Result<String> {
async fn read_file_text(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<String> {
debug!(
"[Conn {}] Reading text from file {:?}",
ctx.connection_id, path
@ -127,12 +101,7 @@ impl DistantApi for SshDistantApi {
Ok(contents)
}
async fn write_file(
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
data: Vec<u8>,
) -> io::Result<()> {
async fn write_file(&self, ctx: DistantCtx, path: PathBuf, data: Vec<u8>) -> io::Result<()> {
debug!(
"[Conn {}] Writing bytes to file {:?}",
ctx.connection_id, path
@ -154,7 +123,7 @@ impl DistantApi for SshDistantApi {
async fn write_file_text(
&self,
ctx: DistantCtx<Self::LocalData>,
ctx: DistantCtx,
path: PathBuf,
data: String,
) -> io::Result<()> {
@ -177,12 +146,7 @@ impl DistantApi for SshDistantApi {
Ok(())
}
async fn append_file(
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
data: Vec<u8>,
) -> io::Result<()> {
async fn append_file(&self, ctx: DistantCtx, path: PathBuf, data: Vec<u8>) -> io::Result<()> {
debug!(
"[Conn {}] Appending bytes to file {:?}",
ctx.connection_id, path
@ -213,7 +177,7 @@ impl DistantApi for SshDistantApi {
async fn append_file_text(
&self,
ctx: DistantCtx<Self::LocalData>,
ctx: DistantCtx,
path: PathBuf,
data: String,
) -> io::Result<()> {
@ -247,7 +211,7 @@ impl DistantApi for SshDistantApi {
async fn read_dir(
&self,
ctx: DistantCtx<Self::LocalData>,
ctx: DistantCtx,
path: PathBuf,
depth: usize,
absolute: bool,
@ -375,12 +339,7 @@ impl DistantApi for SshDistantApi {
Ok((entries, errors))
}
async fn create_dir(
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
all: bool,
) -> io::Result<()> {
async fn create_dir(&self, ctx: DistantCtx, path: PathBuf, all: bool) -> io::Result<()> {
debug!(
"[Conn {}] Creating directory {:?} {{all: {}}}",
ctx.connection_id, path, all
@ -436,12 +395,7 @@ impl DistantApi for SshDistantApi {
Ok(())
}
async fn remove(
&self,
ctx: DistantCtx<Self::LocalData>,
path: PathBuf,
force: bool,
) -> io::Result<()> {
async fn remove(&self, ctx: DistantCtx, path: PathBuf, force: bool) -> io::Result<()> {
debug!(
"[Conn {}] Removing {:?} {{force: {}}}",
ctx.connection_id, path, force
@ -526,12 +480,7 @@ impl DistantApi for SshDistantApi {
Ok(())
}
async fn copy(
&self,
ctx: DistantCtx<Self::LocalData>,
src: PathBuf,
dst: PathBuf,
) -> io::Result<()> {
async fn copy(&self, ctx: DistantCtx, src: PathBuf, dst: PathBuf) -> io::Result<()> {
debug!(
"[Conn {}] Copying {:?} to {:?}",
ctx.connection_id, src, dst
@ -573,12 +522,7 @@ impl DistantApi for SshDistantApi {
}
}
async fn rename(
&self,
ctx: DistantCtx<Self::LocalData>,
src: PathBuf,
dst: PathBuf,
) -> io::Result<()> {
async fn rename(&self, ctx: DistantCtx, src: PathBuf, dst: PathBuf) -> io::Result<()> {
debug!(
"[Conn {}] Renaming {:?} to {:?}",
ctx.connection_id, src, dst
@ -594,7 +538,7 @@ impl DistantApi for SshDistantApi {
Ok(())
}
async fn exists(&self, ctx: DistantCtx<Self::LocalData>, path: PathBuf) -> io::Result<bool> {
async fn exists(&self, ctx: DistantCtx, path: PathBuf) -> io::Result<bool> {
debug!("[Conn {}] Checking if {:?} exists", ctx.connection_id, path);
// NOTE: SFTP does not provide a means to check if a path exists that can be performed
@ -612,7 +556,7 @@ impl DistantApi for SshDistantApi {
async fn metadata(
&self,
ctx: DistantCtx<Self::LocalData>,
ctx: DistantCtx,
path: PathBuf,
canonicalize: bool,
resolve_file_type: bool,
@ -676,7 +620,7 @@ impl DistantApi for SshDistantApi {
#[allow(unreachable_code)]
async fn set_permissions(
&self,
ctx: DistantCtx<Self::LocalData>,
ctx: DistantCtx,
path: PathBuf,
permissions: Permissions,
options: SetPermissionsOptions,
@ -805,7 +749,7 @@ impl DistantApi for SshDistantApi {
async fn proc_spawn(
&self,
ctx: DistantCtx<Self::LocalData>,
ctx: DistantCtx,
cmd: String,
environment: Environment,
current_dir: Option<PathBuf>,
@ -817,14 +761,10 @@ impl DistantApi for SshDistantApi {
);
let global_processes = Arc::downgrade(&self.processes);
let local_processes = Arc::downgrade(&ctx.local_data.processes);
let cleanup = |id: ProcessId| async move {
if let Some(processes) = Weak::upgrade(&global_processes) {
processes.write().await.remove(&id);
}
if let Some(processes) = Weak::upgrade(&local_processes) {
processes.write().await.remove(&id);
}
};
let SpawnResult {
@ -874,7 +814,7 @@ impl DistantApi for SshDistantApi {
Ok(id)
}
async fn proc_kill(&self, ctx: DistantCtx<Self::LocalData>, id: ProcessId) -> io::Result<()> {
async fn proc_kill(&self, ctx: DistantCtx, id: ProcessId) -> io::Result<()> {
debug!("[Conn {}] Killing process {}", ctx.connection_id, id);
if let Some(process) = self.processes.read().await.get(&id) {
@ -892,12 +832,7 @@ impl DistantApi for SshDistantApi {
))
}
async fn proc_stdin(
&self,
ctx: DistantCtx<Self::LocalData>,
id: ProcessId,
data: Vec<u8>,
) -> io::Result<()> {
async fn proc_stdin(&self, ctx: DistantCtx, id: ProcessId, data: Vec<u8>) -> io::Result<()> {
debug!(
"[Conn {}] Sending stdin to process {}",
ctx.connection_id, id
@ -920,7 +855,7 @@ impl DistantApi for SshDistantApi {
async fn proc_resize_pty(
&self,
ctx: DistantCtx<Self::LocalData>,
ctx: DistantCtx,
id: ProcessId,
size: PtySize,
) -> io::Result<()> {
@ -944,7 +879,7 @@ impl DistantApi for SshDistantApi {
))
}
async fn system_info(&self, ctx: DistantCtx<Self::LocalData>) -> io::Result<SystemInfo> {
async fn system_info(&self, ctx: DistantCtx) -> io::Result<SystemInfo> {
// We cache each of these requested values since they should not change for the
// lifetime of the ssh connection
static CURRENT_DIR: OnceCell<PathBuf> = OnceCell::new();
@ -998,7 +933,7 @@ impl DistantApi for SshDistantApi {
})
}
async fn version(&self, ctx: DistantCtx<Self::LocalData>) -> io::Result<Version> {
async fn version(&self, ctx: DistantCtx) -> io::Result<Version> {
debug!("[Conn {}] Querying capabilities", ctx.connection_id);
let mut capabilities = Capabilities::all();

@ -722,7 +722,7 @@ impl Ssh {
}
/// Consumes [`Ssh`] and produces a [`DistantClient`] and [`ServerRef`] pair.
pub async fn into_distant_pair(self) -> io::Result<(DistantClient, Box<dyn ServerRef>)> {
pub async fn into_distant_pair(self) -> io::Result<(DistantClient, ServerRef)> {
// Exit early if not authenticated as this is a requirement
if !self.authenticated {
return Err(io::Error::new(

@ -13,7 +13,7 @@ pub(crate) use common::{Cache, Client, Manager};
/// Represents the primary CLI entrypoint
#[derive(Debug)]
pub struct Cli {
options: Options,
pub options: Options,
}
impl Cli {

@ -1,19 +1,16 @@
use std::collections::HashMap;
use std::io;
use std::io::Write;
use std::path::Path;
use std::path::PathBuf;
use std::path::{Path, PathBuf};
use std::time::Duration;
use anyhow::Context;
use distant_core::net::common::{ConnectionId, Host, Map, Request, Response};
use distant_core::net::manager::ManagerClient;
use distant_core::protocol::SearchQueryContentsMatch;
use distant_core::protocol::SearchQueryMatch;
use distant_core::protocol::SearchQueryPathMatch;
use distant_core::protocol::{
self, Capabilities, ChangeKind, ChangeKindSet, FileType, Permissions, SearchQuery,
SetPermissionsOptions, SystemInfo,
SearchQueryContentsMatch, SearchQueryMatch, SearchQueryPathMatch, SetPermissionsOptions,
SystemInfo,
};
use distant_core::{DistantChannel, DistantChannelExt, RemoteCommand, Searcher, Watcher};
use log::*;

@ -185,7 +185,7 @@ async fn async_run(cmd: ManagerSubcommand) -> CliResult {
"global".to_string()
}
);
let manager_ref = Manager {
let manager = Manager {
access,
config: NetManagerConfig {
user,
@ -223,11 +223,7 @@ async fn async_run(cmd: ManagerSubcommand) -> CliResult {
.context("Failed to start manager")?;
// Let our server run to completion
manager_ref
.as_ref()
.polling_wait()
.await
.context("Failed to wait on manager")?;
manager.await.context("Failed to wait on manager")?;
info!("Manager is shutting down");
Ok(())

@ -3,7 +3,7 @@ use std::io::{self, Read, Write};
use anyhow::Context;
use distant_core::net::auth::Verifier;
use distant_core::net::common::{Host, SecretKey32};
use distant_core::net::server::{Server, ServerConfig as NetServerConfig, ServerRef};
use distant_core::net::server::{Server, ServerConfig as NetServerConfig};
use distant_core::DistantSingleKeyCredentials;
use distant_local::{Config as LocalConfig, WatchConfig as LocalWatchConfig};
use log::*;
@ -212,7 +212,7 @@ async fn async_run(cmd: ServerSubcommand, _is_forked: bool) -> CliResult {
}
// Let our server run to completion
server.wait().await.context("Failed to wait on server")?;
server.await.context("Failed to wait on server")?;
info!("Server is shutting down");
}
}

@ -15,7 +15,7 @@ pub struct Manager {
impl Manager {
/// Begin listening on the network interface specified within [`NetworkConfig`]
pub async fn listen(self) -> anyhow::Result<Box<dyn ServerRef>> {
pub async fn listen(self) -> anyhow::Result<ServerRef> {
let user = self.config.user;
#[cfg(unix)]
@ -36,7 +36,7 @@ impl Manager {
.with_context(|| format!("Failed to create socket directory {parent:?}"))?;
}
let boxed_ref = ManagerServer::new(self.config)
let server = ManagerServer::new(self.config)
.verifier(Verifier::none())
.start(
UnixSocketListener::bind_with_permissions(socket_path, self.access.into_mode())
@ -45,7 +45,7 @@ impl Manager {
.with_context(|| format!("Failed to start manager at socket {socket_path:?}"))?;
info!("Manager listening using unix socket @ {:?}", socket_path);
Ok(boxed_ref)
Ok(server)
}
#[cfg(windows)]
@ -57,13 +57,13 @@ impl Manager {
global_paths::WINDOWS_PIPE_NAME.as_str()
});
let boxed_ref = ManagerServer::new(self.config)
let server = ManagerServer::new(self.config)
.verifier(Verifier::none())
.start(WindowsPipeListener::bind_local(pipe_name)?)
.with_context(|| format!("Failed to start manager at pipe {pipe_name:?}"))?;
info!("Manager listening using windows pipe @ {:?}", pipe_name);
Ok(boxed_ref)
Ok(server)
}
}
}

@ -7,6 +7,7 @@ fn main() -> MainResult {
Err(x) => return MainResult::from(x),
};
let _logger = cli.init_logger();
MainResult::from(cli.run())
}

@ -28,6 +28,10 @@ pub struct Options {
#[clap(flatten)]
pub logging: LoggingSettings,
#[cfg(feature = "tracing")]
#[clap(long, global = true)]
pub tracing: bool,
/// Configuration file to load instead of the default paths
#[clap(short = 'c', long = "config", global = true, value_parser)]
config_path: Option<PathBuf>,

@ -4,7 +4,7 @@ use rstest::*;
use serde_json::json;
use test_log::test;
use crate::cli::fixtures::*;
use crate::common::fixtures::*;
const FILE_CONTENTS: &str = r#"
some text

@ -4,7 +4,7 @@ use rstest::*;
use serde_json::json;
use test_log::test;
use crate::cli::fixtures::*;
use crate::common::fixtures::*;
#[rstest]
#[test(tokio::test)]

@ -5,7 +5,7 @@ use rstest::*;
use serde_json::json;
use test_log::test;
use crate::cli::fixtures::*;
use crate::common::fixtures::*;
/// Creates a directory in the form
///

@ -3,7 +3,7 @@ use rstest::*;
use serde_json::json;
use test_log::test;
use crate::cli::fixtures::*;
use crate::common::fixtures::*;
#[rstest]
#[test(tokio::test)]

@ -3,7 +3,7 @@ use rstest::*;
use serde_json::json;
use test_log::test;
use crate::cli::fixtures::*;
use crate::common::fixtures::*;
const FILE_CONTENTS: &str = r#"
some text

@ -3,7 +3,7 @@ use rstest::*;
use serde_json::json;
use test_log::test;
use crate::cli::fixtures::*;
use crate::common::fixtures::*;
const FILE_CONTENTS: &str = r#"
some text

@ -3,7 +3,7 @@ use rstest::*;
use serde_json::json;
use test_log::test;
use crate::cli::fixtures::*;
use crate::common::fixtures::*;
const FILE_CONTENTS: &str = r#"
some text

@ -3,7 +3,7 @@ use rstest::*;
use serde_json::json;
use test_log::test;
use crate::cli::fixtures::*;
use crate::common::fixtures::*;
const FILE_CONTENTS: &str = r#"
some text

@ -3,7 +3,7 @@ use rstest::*;
use serde_json::json;
use test_log::test;
use crate::cli::fixtures::*;
use crate::common::fixtures::*;
const FILE_CONTENTS: &str = r#"
some text

@ -3,7 +3,7 @@ use rstest::*;
use serde_json::json;
use test_log::test;
use crate::cli::fixtures::*;
use crate::common::fixtures::*;
const FILE_CONTENTS: &str = r#"
some text

@ -3,7 +3,7 @@ use rstest::*;
use serde_json::{json, Value};
use test_log::test;
use crate::cli::fixtures::*;
use crate::common::fixtures::*;
const FILE_CONTENTS: &str = r#"
some text

@ -2,8 +2,8 @@ use rstest::*;
use serde_json::json;
use test_log::test;
use crate::cli::fixtures::*;
use crate::cli::scripts::*;
use crate::common::fixtures::*;
fn make_cmd(args: Vec<&str>) -> String {
format!(

@ -4,7 +4,7 @@ use rstest::*;
use serde_json::json;
use test_log::test;
use crate::cli::fixtures::*;
use crate::common::fixtures::*;
#[rstest]
#[test(tokio::test)]

@ -4,7 +4,7 @@ use rstest::*;
use serde_json::json;
use test_log::test;
use crate::cli::fixtures::*;
use crate::common::fixtures::*;
const FILE_CONTENTS: &str = r#"
some text

@ -3,7 +3,7 @@ use rstest::*;
use serde_json::json;
use test_log::test;
use crate::cli::fixtures::*;
use crate::common::fixtures::*;
#[rstest]
#[test(tokio::test)]

@ -4,7 +4,7 @@ use rstest::*;
use serde_json::json;
use test_log::test;
use crate::cli::fixtures::*;
use crate::common::fixtures::*;
#[rstest]
#[test(tokio::test)]

@ -3,7 +3,7 @@ use rstest::*;
use serde_json::json;
use test_log::test;
use crate::cli::fixtures::*;
use crate::common::fixtures::*;
#[rstest]
#[test(tokio::test)]

@ -5,7 +5,7 @@ use rstest::*;
use serde_json::json;
use test_log::test;
use crate::cli::fixtures::*;
use crate::common::fixtures::*;
async fn wait_a_bit() {
wait_millis(250).await;

@ -2,7 +2,7 @@ use assert_fs::prelude::*;
use predicates::prelude::*;
use rstest::*;
use crate::cli::fixtures::*;
use crate::common::fixtures::*;
const FILE_CONTENTS: &str = r#"
some text

@ -1,7 +1,7 @@
use assert_fs::prelude::*;
use rstest::*;
use crate::cli::fixtures::*;
use crate::common::fixtures::*;
#[rstest]
#[test_log::test]

@ -2,7 +2,7 @@ use assert_fs::prelude::*;
use predicates::prelude::*;
use rstest::*;
use crate::cli::fixtures::*;
use crate::common::fixtures::*;
#[rstest]
#[test_log::test]

@ -2,8 +2,8 @@ use assert_fs::prelude::*;
use predicates::prelude::*;
use rstest::*;
use crate::cli::fixtures::*;
use crate::cli::utils::regex_pred;
use crate::common::fixtures::*;
use crate::common::utils::regex_pred;
const FILE_CONTENTS: &str = r#"
some text

@ -4,8 +4,8 @@ use assert_fs::prelude::*;
use predicates::prelude::*;
use rstest::*;
use crate::cli::fixtures::*;
use crate::cli::utils::regex_pred;
use crate::common::fixtures::*;
use crate::common::utils::regex_pred;
/// Creates a directory in the form
///

@ -3,7 +3,7 @@ use indoc::indoc;
use predicates::prelude::*;
use rstest::*;
use crate::cli::fixtures::*;
use crate::common::fixtures::*;
const FILE_CONTENTS: &str = indoc! {r#"
some text

@ -2,7 +2,7 @@ use assert_fs::prelude::*;
use predicates::prelude::*;
use rstest::*;
use crate::cli::fixtures::*;
use crate::common::fixtures::*;
#[rstest]
#[test_log::test]

@ -2,7 +2,7 @@ use assert_fs::prelude::*;
use predicates::prelude::*;
use rstest::*;
use crate::cli::fixtures::*;
use crate::common::fixtures::*;
const FILE_CONTENTS: &str = r#"
some text

@ -3,7 +3,7 @@ use indoc::indoc;
use predicates::Predicate;
use rstest::*;
use crate::cli::fixtures::*;
use crate::common::fixtures::*;
const SEARCH_RESULTS_REGEX: &str = indoc! {r"
.*?[\\/]file1.txt

@ -4,8 +4,8 @@ use std::time::Duration;
use assert_fs::prelude::*;
use rstest::*;
use crate::cli::fixtures::*;
use crate::cli::utils::ThreadedReader;
use crate::common::fixtures::*;
use crate::common::utils::ThreadedReader;
fn wait_a_bit() {
wait_millis(250);

@ -3,7 +3,7 @@ use indoc::indoc;
use predicates::prelude::*;
use rstest::*;
use crate::cli::fixtures::*;
use crate::common::fixtures::*;
const FILE_CONTENTS: &str = indoc! {r#"
some text

@ -1,8 +1,8 @@
use rstest::*;
use crate::cli::fixtures::*;
use crate::cli::scripts::*;
use crate::cli::utils::regex_pred;
use crate::common::fixtures::*;
use crate::common::utils::regex_pred;
#[rstest]
#[test_log::test]

@ -2,7 +2,7 @@ use std::env;
use rstest::*;
use crate::cli::fixtures::*;
use crate::common::fixtures::*;
#[rstest]
#[test_log::test]

@ -1,8 +1,8 @@
use distant_core::protocol::PROTOCOL_VERSION;
use rstest::*;
use crate::cli::fixtures::*;
use crate::cli::utils::TrimmedLinesMatchPredicate;
use crate::common::fixtures::*;
use crate::common::utils::TrimmedLinesMatchPredicate;
#[rstest]
#[test_log::test]

@ -1,7 +1,7 @@
use indoc::indoc;
use rstest::*;
use crate::cli::fixtures::*;
use crate::common::fixtures::*;
const EXPECTED_TABLE: &str = indoc! {"
+---------------+--------------------------------------------------------------+

@ -1,6 +1,4 @@
mod api;
mod client;
mod fixtures;
mod manager;
mod scripts;
mod utils;

@ -1 +1,2 @@
mod cli;
mod common;

@ -0,0 +1,4 @@
#![allow(dead_code)]
pub mod fixtures;
pub mod utils;

@ -0,0 +1,51 @@
use assert_fs::prelude::*;
use rstest::*;
mod common;
use common::fixtures::*;
#[rstest]
#[test_log::test]
#[ignore]
fn should_handle_large_volume_of_requests(ctx: DistantManagerCtx) {
// Create a temporary directory to house a file we create and edit
// with a large volume of requests
let root = assert_fs::TempDir::new().unwrap();
// Establish a path to a file we will edit repeatedly
let path = root.child("file").to_path_buf();
// Perform many requests of writing a file and reading a file
for i in 1..100 {
ctx.new_assert_cmd(["fs", "write"])
.arg(path.to_str().unwrap())
.write_stdin(format!("idx: {i}"))
.assert();
ctx.new_assert_cmd(["fs", "read"])
.arg(path.to_str().unwrap())
.assert()
.stdout(format!("idx: {i}"));
}
}
#[rstest]
#[test_log::test]
#[ignore]
fn should_handle_wide_spread_of_clients(_ctx: DistantManagerCtx) {
todo!();
}
#[rstest]
#[test_log::test]
#[ignore]
fn should_handle_abrupt_client_disconnects(_ctx: DistantManagerCtx) {
todo!();
}
#[rstest]
#[test_log::test]
#[ignore]
fn should_handle_badly_killing_client_shell_with_interactive_process(_ctx: DistantManagerCtx) {
todo!();
}
Loading…
Cancel
Save