Add heartbeat support (#153)

* Update to support zero-size frame items

* Add heartbeat functionality with client reconnecting logic

* Fix connection reauthentication failures preventing future reauthentication

* More logging

* Remove persist

* Update connection logic to have server take on client id rather than having client take on server id during reconnect

* Bump minimum rust version to 1.64.0

* Bump to v0.20.0-alpha.3 and fix clippy warnings

* Update cargo.lock
pull/156/head v0.20.0-alpha.3
Chip Senkbeil 1 year ago committed by GitHub
parent ee595551ae
commit ee50eaf9b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -86,7 +86,7 @@ jobs:
- { rust: stable, os: windows-latest, target: x86_64-pc-windows-msvc } - { rust: stable, os: windows-latest, target: x86_64-pc-windows-msvc }
- { rust: stable, os: macos-latest } - { rust: stable, os: macos-latest }
- { rust: stable, os: ubuntu-latest } - { rust: stable, os: ubuntu-latest }
- { rust: 1.61.0, os: ubuntu-latest } - { rust: 1.64.0, os: ubuntu-latest }
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- name: Install Rust ${{ matrix.rust }} - name: Install Rust ${{ matrix.rust }}

@ -7,6 +7,29 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased] ## [Unreleased]
## [0.20.0-alpha.3]
### Added
- `Frame::empty` method as convenience for `Frame::new(&[])`
- `ClientConfig` to support `ReconnectStrategy` and a duration serving as the
maximum time to wait between server activity before attempting to reconnect
from the client
- Server sends empty frames periodically to act as heartbeats to let the client
know if the connection is still established
- Client now tracks length of time since last server activity and will attempt
a reconnect if no activity beyond that point
### Changed
- `Frame` methods `read` and `write` no longer return an `io::Result<...>`
and instead return `Option<Frame<...>>` and nothing respectively
- `Frame::read` method now supports zero-size items
- `Client::inmemory_spawn` and `UntypedClient::inmemory_spawn` now take a
`ClientConfig` as the second argument instead of `ReconnectStrategy`
- Persist option now removed from `ProcSpawn` message and CLI
- Bump minimum Rust version to 1.64.0
## [0.20.0-alpha.2] - 2022-11-20 ## [0.20.0-alpha.2] - 2022-11-20
### Added ### Added

593
Cargo.lock generated

File diff suppressed because it is too large Load Diff

@ -3,7 +3,7 @@ name = "distant"
description = "Operate on a remote computer through file and process manipulation" description = "Operate on a remote computer through file and process manipulation"
categories = ["command-line-utilities"] categories = ["command-line-utilities"]
keywords = ["cli"] keywords = ["cli"]
version = "0.20.0-alpha.2" version = "0.20.0-alpha.3"
authors = ["Chip Senkbeil <chip@senkbeil.org>"] authors = ["Chip Senkbeil <chip@senkbeil.org>"]
edition = "2021" edition = "2021"
homepage = "https://github.com/chipsenkbeil/distant" homepage = "https://github.com/chipsenkbeil/distant"
@ -32,7 +32,7 @@ clap_complete = "4.0.5"
config = { version = "0.13.2", default-features = false, features = ["toml"] } config = { version = "0.13.2", default-features = false, features = ["toml"] }
derive_more = { version = "0.99.17", default-features = false, features = ["display", "from", "error", "is_variant"] } derive_more = { version = "0.99.17", default-features = false, features = ["display", "from", "error", "is_variant"] }
dialoguer = { version = "0.10.2", default-features = false } dialoguer = { version = "0.10.2", default-features = false }
distant-core = { version = "=0.20.0-alpha.2", path = "distant-core", features = ["clap", "schemars"] } distant-core = { version = "=0.20.0-alpha.3", path = "distant-core", features = ["clap", "schemars"] }
directories = "4.0.1" directories = "4.0.1"
flexi_logger = "0.24.1" flexi_logger = "0.24.1"
indoc = "1.0.7" indoc = "1.0.7"
@ -54,7 +54,7 @@ winsplit = "0.1.0"
whoami = "1.2.3" whoami = "1.2.3"
# Optional native SSH functionality # Optional native SSH functionality
distant-ssh2 = { version = "=0.20.0-alpha.2", path = "distant-ssh2", default-features = false, features = ["serde"], optional = true } distant-ssh2 = { version = "=0.20.0-alpha.3", path = "distant-ssh2", default-features = false, features = ["serde"], optional = true }
[target.'cfg(unix)'.dependencies] [target.'cfg(unix)'.dependencies]
fork = "0.1.20" fork = "0.1.20"

@ -3,7 +3,7 @@ name = "distant-core"
description = "Core library for distant, enabling operation on a remote computer through file and process manipulation" description = "Core library for distant, enabling operation on a remote computer through file and process manipulation"
categories = ["network-programming"] categories = ["network-programming"]
keywords = ["api", "async"] keywords = ["api", "async"]
version = "0.20.0-alpha.2" version = "0.20.0-alpha.3"
authors = ["Chip Senkbeil <chip@senkbeil.org>"] authors = ["Chip Senkbeil <chip@senkbeil.org>"]
edition = "2021" edition = "2021"
homepage = "https://github.com/chipsenkbeil/distant" homepage = "https://github.com/chipsenkbeil/distant"
@ -19,7 +19,7 @@ async-trait = "0.1.58"
bitflags = "1.3.2" bitflags = "1.3.2"
bytes = "1.2.1" bytes = "1.2.1"
derive_more = { version = "0.99.17", default-features = false, features = ["as_mut", "as_ref", "deref", "deref_mut", "display", "from", "error", "into", "into_iterator", "is_variant", "try_into"] } derive_more = { version = "0.99.17", default-features = false, features = ["as_mut", "as_ref", "deref", "deref_mut", "display", "from", "error", "into", "into_iterator", "is_variant", "try_into"] }
distant-net = { version = "=0.20.0-alpha.2", path = "../distant-net" } distant-net = { version = "=0.20.0-alpha.3", path = "../distant-net" }
futures = "0.3.25" futures = "0.3.25"
grep = "0.2.10" grep = "0.2.10"
hex = "0.4.3" hex = "0.4.3"

@ -348,8 +348,6 @@ pub trait DistantApi {
/// * `cmd` - the full command to run as a new process (including arguments) /// * `cmd` - the full command to run as a new process (including arguments)
/// * `environment` - the environment variables to associate with the process /// * `environment` - the environment variables to associate with the process
/// * `current_dir` - the alternative current directory to use with the process /// * `current_dir` - the alternative current directory to use with the process
/// * `persist` - if true, the process will continue running even after the connection that
/// spawned the process has terminated
/// * `pty` - if provided, will run the process within a PTY of the given size /// * `pty` - if provided, will run the process within a PTY of the given size
/// ///
/// *Override this, otherwise it will return "unsupported" as an error.* /// *Override this, otherwise it will return "unsupported" as an error.*
@ -360,7 +358,6 @@ pub trait DistantApi {
cmd: String, cmd: String,
environment: Environment, environment: Environment,
current_dir: Option<PathBuf>, current_dir: Option<PathBuf>,
persist: bool,
pty: Option<PtySize>, pty: Option<PtySize>,
) -> io::Result<ProcessId> { ) -> io::Result<ProcessId> {
unsupported("proc_spawn") unsupported("proc_spawn")
@ -650,11 +647,10 @@ where
cmd, cmd,
environment, environment,
current_dir, current_dir,
persist,
pty, pty,
} => server } => server
.api .api
.proc_spawn(ctx, cmd.into(), environment, current_dir, persist, pty) .proc_spawn(ctx, cmd.into(), environment, current_dir, pty)
.await .await
.map(|id| DistantResponseData::ProcSpawned { id }) .map(|id| DistantResponseData::ProcSpawned { id })
.unwrap_or_else(DistantResponseData::from), .unwrap_or_else(DistantResponseData::from),

@ -6,7 +6,6 @@ use crate::{
DistantApi, DistantCtx, DistantApi, DistantCtx,
}; };
use async_trait::async_trait; use async_trait::async_trait;
use distant_net::server::ConnectionCtx;
use log::*; use log::*;
use std::{ use std::{
io, io,
@ -18,7 +17,6 @@ use walkdir::WalkDir;
mod process; mod process;
mod state; mod state;
pub use state::ConnectionState;
use state::*; use state::*;
/// Represents an implementation of [`DistantApi`] that works with the local machine /// Represents an implementation of [`DistantApi`] that works with the local machine
@ -40,14 +38,7 @@ impl LocalDistantApi {
#[async_trait] #[async_trait]
impl DistantApi for LocalDistantApi { impl DistantApi for LocalDistantApi {
type LocalData = ConnectionState; type LocalData = ();
/// Injects the global channels into the local connection
async fn on_accept(&self, ctx: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> {
ctx.local_data.process_channel = self.state.process.clone_channel();
ctx.local_data.watcher_channel = self.state.watcher.clone_channel();
Ok(())
}
async fn capabilities(&self, ctx: DistantCtx<Self::LocalData>) -> io::Result<Capabilities> { async fn capabilities(&self, ctx: DistantCtx<Self::LocalData>) -> io::Result<Capabilities> {
debug!("[Conn {}] Querying capabilities", ctx.connection_id); debug!("[Conn {}] Querying capabilities", ctx.connection_id);
@ -451,16 +442,15 @@ impl DistantApi for LocalDistantApi {
cmd: String, cmd: String,
environment: Environment, environment: Environment,
current_dir: Option<PathBuf>, current_dir: Option<PathBuf>,
persist: bool,
pty: Option<PtySize>, pty: Option<PtySize>,
) -> io::Result<ProcessId> { ) -> io::Result<ProcessId> {
debug!( debug!(
"[Conn {}] Spawning {} {{environment: {:?}, current_dir: {:?}, persist: {}, pty: {:?}}}", "[Conn {}] Spawning {} {{environment: {:?}, current_dir: {:?}, pty: {:?}}}",
ctx.connection_id, cmd, environment, current_dir, persist, pty ctx.connection_id, cmd, environment, current_dir, pty
); );
self.state self.state
.process .process
.spawn(cmd, environment, current_dir, persist, pty, ctx.reply) .spawn(cmd, environment, current_dir, pty, ctx.reply)
.await .await
} }
@ -504,6 +494,7 @@ impl DistantApi for LocalDistantApi {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::api::ConnectionCtx;
use crate::data::DistantResponseData; use crate::data::DistantResponseData;
use assert_fs::prelude::*; use assert_fs::prelude::*;
use distant_net::server::Reply; use distant_net::server::Reply;
@ -576,18 +567,18 @@ mod tests {
buffer: usize, buffer: usize,
) -> ( ) -> (
LocalDistantApi, LocalDistantApi,
DistantCtx<ConnectionState>, DistantCtx<()>,
mpsc::Receiver<DistantResponseData>, mpsc::Receiver<DistantResponseData>,
) { ) {
let api = LocalDistantApi::initialize().unwrap(); let api = LocalDistantApi::initialize().unwrap();
let (reply, rx) = make_reply(buffer); let (reply, rx) = make_reply(buffer);
let connection_id = rand::random(); let connection_id = rand::random();
let mut local_data = ConnectionState::default();
DistantApi::on_accept( DistantApi::on_accept(
&api, &api,
ConnectionCtx { ConnectionCtx {
connection_id, connection_id,
local_data: &mut local_data, local_data: &mut (),
}, },
) )
.await .await
@ -595,7 +586,7 @@ mod tests {
let ctx = DistantCtx { let ctx = DistantCtx {
connection_id, connection_id,
reply, reply,
local_data: Arc::new(local_data), local_data: Arc::new(()),
}; };
(api, ctx, rx) (api, ctx, rx)
} }
@ -1842,7 +1833,6 @@ mod tests {
/* cmd */ DOES_NOT_EXIST_BIN.to_str().unwrap().to_string(), /* cmd */ DOES_NOT_EXIST_BIN.to_str().unwrap().to_string(),
/* environment */ Environment::new(), /* environment */ Environment::new(),
/* current_dir */ None, /* current_dir */ None,
/* persist */ false,
/* pty */ None, /* pty */ None,
) )
.await .await
@ -1867,7 +1857,6 @@ mod tests {
), ),
/* environment */ Environment::new(), /* environment */ Environment::new(),
/* current_dir */ None, /* current_dir */ None,
/* persist */ false,
/* pty */ None, /* pty */ None,
) )
.await .await
@ -1893,7 +1882,6 @@ mod tests {
), ),
/* environment */ Environment::new(), /* environment */ Environment::new(),
/* current_dir */ None, /* current_dir */ None,
/* persist */ false,
/* pty */ None, /* pty */ None,
) )
.await .await
@ -1958,7 +1946,6 @@ mod tests {
), ),
/* environment */ Environment::new(), /* environment */ Environment::new(),
/* current_dir */ None, /* current_dir */ None,
/* persist */ false,
/* pty */ None, /* pty */ None,
) )
.await .await
@ -2019,7 +2006,6 @@ mod tests {
format!("{} {} 0.1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()), format!("{} {} 0.1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()),
/* environment */ Environment::new(), /* environment */ Environment::new(),
/* current_dir */ None, /* current_dir */ None,
/* persist */ false,
/* pty */ None, /* pty */ None,
) )
.await .await
@ -2059,7 +2045,6 @@ mod tests {
format!("{} {} 1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()), format!("{} {} 1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()),
/* environment */ Environment::new(), /* environment */ Environment::new(),
/* current_dir */ None, /* current_dir */ None,
/* persist */ false,
/* pty */ None, /* pty */ None,
) )
.await .await
@ -2126,7 +2111,6 @@ mod tests {
), ),
Environment::new(), Environment::new(),
/* current_dir */ None, /* current_dir */ None,
/* persist */ false,
/* pty */ None, /* pty */ None,
) )
.await .await

@ -1,6 +1,4 @@
use crate::data::{ProcessId, SearchId}; use std::io;
use distant_net::common::ConnectionId;
use std::{io, path::PathBuf};
mod process; mod process;
pub use process::*; pub use process::*;
@ -32,57 +30,3 @@ impl GlobalState {
}) })
} }
} }
/// Holds connection-specific state managed by the server
#[derive(Default)]
pub struct ConnectionState {
/// Unique id associated with connection
id: ConnectionId,
/// Channel connected to global process state
pub(crate) process_channel: ProcessChannel,
/// Channel connected to global search state
pub(crate) search_channel: SearchChannel,
/// Channel connected to global watcher state
pub(crate) watcher_channel: WatcherChannel,
/// Contains ids of processes that will be terminated when the connection is closed
processes: Vec<ProcessId>,
/// Contains paths being watched that will be unwatched when the connection is closed
paths: Vec<PathBuf>,
/// Contains ids of searches that will be terminated when the connection is closed
searches: Vec<SearchId>,
}
impl Drop for ConnectionState {
fn drop(&mut self) {
let id = self.id;
let processes: Vec<ProcessId> = self.processes.drain(..).collect();
let paths: Vec<PathBuf> = self.paths.drain(..).collect();
let searches: Vec<SearchId> = self.searches.drain(..).collect();
let process_channel = self.process_channel.clone();
let search_channel = self.search_channel.clone();
let watcher_channel = self.watcher_channel.clone();
// NOTE: We cannot (and should not) block during drop to perform cleanup,
// instead spawning a task that will do the cleanup async
tokio::spawn(async move {
for id in processes {
let _ = process_channel.kill(id).await;
}
for id in searches {
let _ = search_channel.cancel(id).await;
}
for path in paths {
let _ = watcher_channel.unwatch(id, path).await;
}
});
}
}

@ -9,14 +9,14 @@ use tokio::{
mod instance; mod instance;
pub use instance::*; pub use instance::*;
/// Holds information related to spawned processes on the server /// Holds information related to spawned processes on the server.
pub struct ProcessState { pub struct ProcessState {
channel: ProcessChannel, channel: ProcessChannel,
task: JoinHandle<()>, task: JoinHandle<()>,
} }
impl Drop for ProcessState { impl Drop for ProcessState {
/// Aborts the task that handles process operations and management /// Aborts the task that handles process operations and management.
fn drop(&mut self) { fn drop(&mut self) {
self.abort(); self.abort();
} }
@ -33,10 +33,6 @@ impl ProcessState {
} }
} }
pub fn clone_channel(&self) -> ProcessChannel {
self.channel.clone()
}
/// Aborts the process task /// Aborts the process task
pub fn abort(&self) { pub fn abort(&self) {
self.task.abort(); self.task.abort();
@ -57,7 +53,7 @@ pub struct ProcessChannel {
} }
impl Default for ProcessChannel { impl Default for ProcessChannel {
/// Creates a new channel that is closed by default /// Creates a new channel that is closed by default.
fn default() -> Self { fn default() -> Self {
let (tx, _) = mpsc::channel(1); let (tx, _) = mpsc::channel(1);
Self { tx } Self { tx }
@ -65,13 +61,12 @@ impl Default for ProcessChannel {
} }
impl ProcessChannel { impl ProcessChannel {
/// Spawns a new process, returning the id associated with it /// Spawns a new process, returning the id associated with it.
pub async fn spawn( pub async fn spawn(
&self, &self,
cmd: String, cmd: String,
environment: Environment, environment: Environment,
current_dir: Option<PathBuf>, current_dir: Option<PathBuf>,
persist: bool,
pty: Option<PtySize>, pty: Option<PtySize>,
reply: Box<dyn Reply<Data = DistantResponseData>>, reply: Box<dyn Reply<Data = DistantResponseData>>,
) -> io::Result<ProcessId> { ) -> io::Result<ProcessId> {
@ -81,7 +76,6 @@ impl ProcessChannel {
cmd, cmd,
environment, environment,
current_dir, current_dir,
persist,
pty, pty,
reply, reply,
cb, cb,
@ -92,7 +86,7 @@ impl ProcessChannel {
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Response to spawn dropped"))? .map_err(|_| io::Error::new(io::ErrorKind::Other, "Response to spawn dropped"))?
} }
/// Resizes the pty of a running process /// Resizes the pty of a running process.
pub async fn resize_pty(&self, id: ProcessId, size: PtySize) -> io::Result<()> { pub async fn resize_pty(&self, id: ProcessId, size: PtySize) -> io::Result<()> {
let (cb, rx) = oneshot::channel(); let (cb, rx) = oneshot::channel();
self.tx self.tx
@ -103,7 +97,7 @@ impl ProcessChannel {
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Response to resize dropped"))? .map_err(|_| io::Error::new(io::ErrorKind::Other, "Response to resize dropped"))?
} }
/// Send stdin to a running process /// Send stdin to a running process.
pub async fn send_stdin(&self, id: ProcessId, data: Vec<u8>) -> io::Result<()> { pub async fn send_stdin(&self, id: ProcessId, data: Vec<u8>) -> io::Result<()> {
let (cb, rx) = oneshot::channel(); let (cb, rx) = oneshot::channel();
self.tx self.tx
@ -114,7 +108,8 @@ impl ProcessChannel {
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Response to stdin dropped"))? .map_err(|_| io::Error::new(io::ErrorKind::Other, "Response to stdin dropped"))?
} }
/// Kills a running process /// Kills a running process, including persistent processes if `force` is true. Will fail if
/// unable to kill the process or `force` is false when the process is persistent.
pub async fn kill(&self, id: ProcessId) -> io::Result<()> { pub async fn kill(&self, id: ProcessId) -> io::Result<()> {
let (cb, rx) = oneshot::channel(); let (cb, rx) = oneshot::channel();
self.tx self.tx
@ -126,13 +121,12 @@ impl ProcessChannel {
} }
} }
/// Internal message to pass to our task below to perform some action /// Internal message to pass to our task below to perform some action.
enum InnerProcessMsg { enum InnerProcessMsg {
Spawn { Spawn {
cmd: String, cmd: String,
environment: Environment, environment: Environment,
current_dir: Option<PathBuf>, current_dir: Option<PathBuf>,
persist: bool,
pty: Option<PtySize>, pty: Option<PtySize>,
reply: Box<dyn Reply<Data = DistantResponseData>>, reply: Box<dyn Reply<Data = DistantResponseData>>,
cb: oneshot::Sender<io::Result<ProcessId>>, cb: oneshot::Sender<io::Result<ProcessId>>,
@ -165,14 +159,12 @@ async fn process_task(tx: mpsc::Sender<InnerProcessMsg>, mut rx: mpsc::Receiver<
cmd, cmd,
environment, environment,
current_dir, current_dir,
persist,
pty, pty,
reply, reply,
cb, cb,
} => { } => {
let _ = cb.send( let _ = cb.send(
match ProcessInstance::spawn(cmd, environment, current_dir, persist, pty, reply) match ProcessInstance::spawn(cmd, environment, current_dir, pty, reply) {
{
Ok(mut process) => { Ok(mut process) => {
let id = process.id; let id = process.id;
@ -195,7 +187,7 @@ async fn process_task(tx: mpsc::Sender<InnerProcessMsg>, mut rx: mpsc::Receiver<
Some(process) => process.pty.resize_pty(size), Some(process) => process.pty.resize_pty(size),
None => Err(io::Error::new( None => Err(io::Error::new(
io::ErrorKind::Other, io::ErrorKind::Other,
format!("No process found with id {}", id), format!("No process found with id {id}"),
)), )),
}); });
} }
@ -205,12 +197,12 @@ async fn process_task(tx: mpsc::Sender<InnerProcessMsg>, mut rx: mpsc::Receiver<
Some(stdin) => stdin.send(&data).await, Some(stdin) => stdin.send(&data).await,
None => Err(io::Error::new( None => Err(io::Error::new(
io::ErrorKind::Other, io::ErrorKind::Other,
format!("Process {} stdin is closed", id), format!("Process {id} stdin is closed"),
)), )),
}, },
None => Err(io::Error::new( None => Err(io::Error::new(
io::ErrorKind::Other, io::ErrorKind::Other,
format!("No process found with id {}", id), format!("No process found with id {id}"),
)), )),
}); });
} }
@ -219,7 +211,7 @@ async fn process_task(tx: mpsc::Sender<InnerProcessMsg>, mut rx: mpsc::Receiver<
Some(process) => process.killer.kill().await, Some(process) => process.killer.kill().await,
None => Err(io::Error::new( None => Err(io::Error::new(
io::ErrorKind::Other, io::ErrorKind::Other,
format!("No process found with id {}", id), format!("No process found with id {id}"),
)), )),
}); });
} }

@ -13,7 +13,6 @@ use tokio::task::JoinHandle;
pub struct ProcessInstance { pub struct ProcessInstance {
pub cmd: String, pub cmd: String,
pub args: Vec<String>, pub args: Vec<String>,
pub persist: bool,
pub id: ProcessId, pub id: ProcessId,
pub stdin: Option<Box<dyn InputChannel>>, pub stdin: Option<Box<dyn InputChannel>>,
@ -63,7 +62,6 @@ impl ProcessInstance {
cmd: String, cmd: String,
environment: Environment, environment: Environment,
current_dir: Option<PathBuf>, current_dir: Option<PathBuf>,
persist: bool,
pty: Option<PtySize>, pty: Option<PtySize>,
reply: Box<dyn Reply<Data = DistantResponseData>>, reply: Box<dyn Reply<Data = DistantResponseData>>,
) -> io::Result<Self> { ) -> io::Result<Self> {
@ -135,7 +133,6 @@ impl ProcessInstance {
Ok(ProcessInstance { Ok(ProcessInstance {
cmd, cmd,
args, args,
persist,
id, id,
stdin, stdin,
killer, killer,

@ -95,10 +95,6 @@ impl WatcherState {
} }
} }
pub fn clone_channel(&self) -> WatcherChannel {
self.channel.clone()
}
/// Aborts the watcher task /// Aborts the watcher task
pub fn abort(&self) { pub fn abort(&self) {
self.task.abort(); self.task.abort();

@ -101,7 +101,6 @@ pub trait DistantChannelExt {
cmd: impl Into<String>, cmd: impl Into<String>,
environment: Environment, environment: Environment,
current_dir: Option<PathBuf>, current_dir: Option<PathBuf>,
persist: bool,
pty: Option<PtySize>, pty: Option<PtySize>,
) -> AsyncReturn<'_, RemoteProcess>; ) -> AsyncReturn<'_, RemoteProcess>;
@ -111,7 +110,6 @@ pub trait DistantChannelExt {
cmd: impl Into<String>, cmd: impl Into<String>,
environment: Environment, environment: Environment,
current_dir: Option<PathBuf>, current_dir: Option<PathBuf>,
persist: bool,
pty: Option<PtySize>, pty: Option<PtySize>,
) -> AsyncReturn<'_, RemoteLspProcess>; ) -> AsyncReturn<'_, RemoteLspProcess>;
@ -369,7 +367,6 @@ impl DistantChannelExt
cmd: impl Into<String>, cmd: impl Into<String>,
environment: Environment, environment: Environment,
current_dir: Option<PathBuf>, current_dir: Option<PathBuf>,
persist: bool,
pty: Option<PtySize>, pty: Option<PtySize>,
) -> AsyncReturn<'_, RemoteProcess> { ) -> AsyncReturn<'_, RemoteProcess> {
let cmd = cmd.into(); let cmd = cmd.into();
@ -377,7 +374,6 @@ impl DistantChannelExt
RemoteCommand::new() RemoteCommand::new()
.environment(environment) .environment(environment)
.current_dir(current_dir) .current_dir(current_dir)
.persist(persist)
.pty(pty) .pty(pty)
.spawn(self.clone(), cmd) .spawn(self.clone(), cmd)
.await .await
@ -389,7 +385,6 @@ impl DistantChannelExt
cmd: impl Into<String>, cmd: impl Into<String>,
environment: Environment, environment: Environment,
current_dir: Option<PathBuf>, current_dir: Option<PathBuf>,
persist: bool,
pty: Option<PtySize>, pty: Option<PtySize>,
) -> AsyncReturn<'_, RemoteLspProcess> { ) -> AsyncReturn<'_, RemoteLspProcess> {
let cmd = cmd.into(); let cmd = cmd.into();
@ -397,7 +392,6 @@ impl DistantChannelExt
RemoteLspCommand::new() RemoteLspCommand::new()
.environment(environment) .environment(environment)
.current_dir(current_dir) .current_dir(current_dir)
.persist(persist)
.pty(pty) .pty(pty)
.spawn(self.clone(), cmd) .spawn(self.clone(), cmd)
.await .await
@ -416,7 +410,6 @@ impl DistantChannelExt
RemoteCommand::new() RemoteCommand::new()
.environment(environment) .environment(environment)
.current_dir(current_dir) .current_dir(current_dir)
.persist(false)
.pty(pty) .pty(pty)
.spawn(self.clone(), cmd) .spawn(self.clone(), cmd)
.await? .await?

@ -22,7 +22,6 @@ pub use msg::*;
/// A [`RemoteLspProcess`] builder providing support to configure /// A [`RemoteLspProcess`] builder providing support to configure
/// before spawning the process on a remote machine /// before spawning the process on a remote machine
pub struct RemoteLspCommand { pub struct RemoteLspCommand {
persist: bool,
pty: Option<PtySize>, pty: Option<PtySize>,
environment: Environment, environment: Environment,
current_dir: Option<PathBuf>, current_dir: Option<PathBuf>,
@ -38,21 +37,12 @@ impl RemoteLspCommand {
/// Creates a new set of options for a remote LSP process /// Creates a new set of options for a remote LSP process
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
persist: false,
pty: None, pty: None,
environment: Environment::new(), environment: Environment::new(),
current_dir: None, current_dir: None,
} }
} }
/// Sets whether or not the process will be persistent,
/// meaning that it will not be terminated when the
/// connection to the remote machine is terminated
pub fn persist(&mut self, persist: bool) -> &mut Self {
self.persist = persist;
self
}
/// Configures the process to leverage a PTY with the specified size /// Configures the process to leverage a PTY with the specified size
pub fn pty(&mut self, pty: Option<PtySize>) -> &mut Self { pub fn pty(&mut self, pty: Option<PtySize>) -> &mut Self {
self.pty = pty; self.pty = pty;
@ -81,7 +71,6 @@ impl RemoteLspCommand {
let mut command = RemoteCommand::new(); let mut command = RemoteCommand::new();
command.environment(self.environment.clone()); command.environment(self.environment.clone());
command.current_dir(self.current_dir.clone()); command.current_dir(self.current_dir.clone());
command.persist(self.persist);
command.pty(self.pty); command.pty(self.pty);
let mut inner = command.spawn(channel, cmd).await?; let mut inner = command.spawn(channel, cmd).await?;
@ -412,7 +401,7 @@ mod tests {
use crate::data::{DistantRequestData, DistantResponseData}; use crate::data::{DistantRequestData, DistantResponseData};
use distant_net::{ use distant_net::{
common::{FramedTransport, InmemoryTransport, Request, Response}, common::{FramedTransport, InmemoryTransport, Request, Response},
Client, ReconnectStrategy, Client,
}; };
use std::{future::Future, time::Duration}; use std::{future::Future, time::Duration};
use test_log::test; use test_log::test;
@ -423,7 +412,7 @@ mod tests {
// Configures an lsp process with a means to send & receive data from outside // Configures an lsp process with a means to send & receive data from outside
async fn spawn_lsp_process() -> (FramedTransport<InmemoryTransport>, RemoteLspProcess) { async fn spawn_lsp_process() -> (FramedTransport<InmemoryTransport>, RemoteLspProcess) {
let (mut t1, t2) = FramedTransport::pair(100); let (mut t1, t2) = FramedTransport::pair(100);
let client = Client::spawn_inmemory(t2, ReconnectStrategy::Fail); let client = Client::spawn_inmemory(t2, Default::default());
let spawn_task = tokio::spawn({ let spawn_task = tokio::spawn({
let channel = client.clone_channel(); let channel = client.clone_channel();
async move { async move {

@ -47,7 +47,6 @@ type StatusResult = io::Result<RemoteStatus>;
/// A [`RemoteProcess`] builder providing support to configure /// A [`RemoteProcess`] builder providing support to configure
/// before spawning the process on a remote machine /// before spawning the process on a remote machine
pub struct RemoteCommand { pub struct RemoteCommand {
persist: bool,
pty: Option<PtySize>, pty: Option<PtySize>,
environment: Environment, environment: Environment,
current_dir: Option<PathBuf>, current_dir: Option<PathBuf>,
@ -63,21 +62,12 @@ impl RemoteCommand {
/// Creates a new set of options for a remote process /// Creates a new set of options for a remote process
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
persist: false,
pty: None, pty: None,
environment: Environment::new(), environment: Environment::new(),
current_dir: None, current_dir: None,
} }
} }
/// Sets whether or not the process will be persistent,
/// meaning that it will not be terminated when the
/// connection to the remote machine is terminated
pub fn persist(&mut self, persist: bool) -> &mut Self {
self.persist = persist;
self
}
/// Configures the process to leverage a PTY with the specified size /// Configures the process to leverage a PTY with the specified size
pub fn pty(&mut self, pty: Option<PtySize>) -> &mut Self { pub fn pty(&mut self, pty: Option<PtySize>) -> &mut Self {
self.pty = pty; self.pty = pty;
@ -109,7 +99,6 @@ impl RemoteCommand {
.mail(Request::new(DistantMsg::Single( .mail(Request::new(DistantMsg::Single(
DistantRequestData::ProcSpawn { DistantRequestData::ProcSpawn {
cmd: Cmd::from(cmd), cmd: Cmd::from(cmd),
persist: self.persist,
pty: self.pty, pty: self.pty,
environment: self.environment.clone(), environment: self.environment.clone(),
current_dir: self.current_dir.clone(), current_dir: self.current_dir.clone(),
@ -613,14 +602,14 @@ mod tests {
}; };
use distant_net::{ use distant_net::{
common::{FramedTransport, InmemoryTransport, Response}, common::{FramedTransport, InmemoryTransport, Response},
Client, ReconnectStrategy, Client,
}; };
use std::time::Duration; use std::time::Duration;
use test_log::test; use test_log::test;
fn make_session() -> (FramedTransport<InmemoryTransport>, DistantClient) { fn make_session() -> (FramedTransport<InmemoryTransport>, DistantClient) {
let (t1, t2) = FramedTransport::pair(100); let (t1, t2) = FramedTransport::pair(100);
(t1, Client::spawn_inmemory(t2, ReconnectStrategy::Fail)) (t1, Client::spawn_inmemory(t2, Default::default()))
} }
#[test(tokio::test)] #[test(tokio::test)]

@ -198,7 +198,7 @@ mod tests {
use crate::DistantClient; use crate::DistantClient;
use distant_net::{ use distant_net::{
common::{FramedTransport, InmemoryTransport, Response}, common::{FramedTransport, InmemoryTransport, Response},
Client, ReconnectStrategy, Client,
}; };
use std::{path::PathBuf, sync::Arc}; use std::{path::PathBuf, sync::Arc};
use test_log::test; use test_log::test;
@ -206,7 +206,7 @@ mod tests {
fn make_session() -> (FramedTransport<InmemoryTransport>, DistantClient) { fn make_session() -> (FramedTransport<InmemoryTransport>, DistantClient) {
let (t1, t2) = FramedTransport::pair(100); let (t1, t2) = FramedTransport::pair(100);
(t1, Client::spawn_inmemory(t2, ReconnectStrategy::Fail)) (t1, Client::spawn_inmemory(t2, Default::default()))
} }
#[test(tokio::test)] #[test(tokio::test)]

@ -186,7 +186,7 @@ mod tests {
use crate::DistantClient; use crate::DistantClient;
use distant_net::{ use distant_net::{
common::{FramedTransport, InmemoryTransport, Response}, common::{FramedTransport, InmemoryTransport, Response},
Client, ReconnectStrategy, Client,
}; };
use std::sync::Arc; use std::sync::Arc;
use test_log::test; use test_log::test;
@ -194,7 +194,7 @@ mod tests {
fn make_session() -> (FramedTransport<InmemoryTransport>, DistantClient) { fn make_session() -> (FramedTransport<InmemoryTransport>, DistantClient) {
let (t1, t2) = FramedTransport::pair(100); let (t1, t2) = FramedTransport::pair(100);
(t1, Client::spawn_inmemory(t2, ReconnectStrategy::Fail)) (t1, Client::spawn_inmemory(t2, Default::default()))
} }
#[test(tokio::test)] #[test(tokio::test)]

@ -417,12 +417,6 @@ pub enum DistantRequestData {
#[cfg_attr(feature = "clap", clap(long))] #[cfg_attr(feature = "clap", clap(long))]
current_dir: Option<PathBuf>, current_dir: Option<PathBuf>,
/// Whether or not the process should be persistent, meaning that the process will not be
/// killed when the associated client disconnects
#[serde(default)]
#[cfg_attr(feature = "clap", clap(long))]
persist: bool,
/// If provided, will spawn process in a pty, otherwise spawns directly /// If provided, will spawn process in a pty, otherwise spawns directly
#[serde(default)] #[serde(default)]
#[cfg_attr(feature = "clap", clap(long))] #[cfg_attr(feature = "clap", clap(long))]

@ -45,7 +45,7 @@ impl DistantClientCtx {
// Now initialize our client // Now initialize our client
let client: DistantClient = Client::build() let client: DistantClient = Client::build()
.auth_handler(DummyAuthHandler) .auth_handler(DummyAuthHandler)
.timeout(Duration::from_secs(1)) .connect_timeout(Duration::from_secs(1))
.connector(TcpConnector::new( .connector(TcpConnector::new(
format!("{}:{}", ip_addr, port) format!("{}:{}", ip_addr, port)
.parse::<SocketAddr>() .parse::<SocketAddr>()

@ -3,7 +3,7 @@ name = "distant-net"
description = "Network library for distant, providing implementations to support client/server architecture" description = "Network library for distant, providing implementations to support client/server architecture"
categories = ["network-programming"] categories = ["network-programming"]
keywords = ["api", "async"] keywords = ["api", "async"]
version = "0.20.0-alpha.2" version = "0.20.0-alpha.3"
authors = ["Chip Senkbeil <chip@senkbeil.org>"] authors = ["Chip Senkbeil <chip@senkbeil.org>"]
edition = "2021" edition = "2021"
homepage = "https://github.com/chipsenkbeil/distant" homepage = "https://github.com/chipsenkbeil/distant"

@ -8,7 +8,7 @@ use std::{
fmt, io, fmt, io,
ops::{Deref, DerefMut}, ops::{Deref, DerefMut},
sync::Arc, sync::Arc,
time::Duration, time::{Duration, Instant},
}; };
use tokio::{ use tokio::{
sync::{mpsc, oneshot, watch}, sync::{mpsc, oneshot, watch},
@ -21,6 +21,9 @@ pub use builder::*;
mod channel; mod channel;
pub use channel::*; pub use channel::*;
mod config;
pub use config::*;
mod reconnect; mod reconnect;
pub use reconnect::*; pub use reconnect::*;
@ -135,18 +138,18 @@ impl UntypedClient {
/// within a program. /// within a program.
pub fn spawn_inmemory( pub fn spawn_inmemory(
transport: FramedTransport<InmemoryTransport>, transport: FramedTransport<InmemoryTransport>,
strategy: ReconnectStrategy, config: ClientConfig,
) -> Self { ) -> Self {
let connection = Connection::Client { let connection = Connection::Client {
id: rand::random(), id: rand::random(),
reauth_otp: HeapSecretKey::generate(32).unwrap(), reauth_otp: HeapSecretKey::generate(32).unwrap(),
transport, transport,
}; };
Self::spawn(connection, strategy) Self::spawn(connection, config)
} }
/// Spawns a client using the provided [`Connection`]. /// Spawns a client using the provided [`Connection`].
pub(crate) fn spawn<V>(mut connection: Connection<V>, mut strategy: ReconnectStrategy) -> Self pub(crate) fn spawn<V>(mut connection: Connection<V>, config: ClientConfig) -> Self
where where
V: Transport + 'static, V: Transport + 'static,
{ {
@ -164,6 +167,7 @@ impl UntypedClient {
let (watcher_tx, watcher_rx) = watch::channel(ConnectionState::Connected); let (watcher_tx, watcher_rx) = watch::channel(ConnectionState::Connected);
let task = tokio::spawn(async move { let task = tokio::spawn(async move {
let mut needs_reconnect = false; let mut needs_reconnect = false;
let mut last_read_frame_time = Instant::now();
// NOTE: We hold onto a copy of the shutdown sender, even though we will never use it, // NOTE: We hold onto a copy of the shutdown sender, even though we will never use it,
// to prevent the channel from being closed. This is because we do a check to // to prevent the channel from being closed. This is because we do a check to
@ -171,19 +175,24 @@ impl UntypedClient {
// would cause recv() to resolve immediately and result in the task shutting // would cause recv() to resolve immediately and result in the task shutting
// down. // down.
let _shutdown_tx = shutdown_tx_2; let _shutdown_tx = shutdown_tx_2;
let ClientConfig {
mut reconnect_strategy,
silence_duration,
} = config;
loop { loop {
// If we have flagged that a reconnect is needed, attempt to do so // If we have flagged that a reconnect is needed, attempt to do so
if needs_reconnect { if needs_reconnect {
info!("Client encountered issue, attempting to reconnect"); info!("Client encountered issue, attempting to reconnect");
if log::log_enabled!(log::Level::Debug) { if log::log_enabled!(log::Level::Debug) {
debug!("Using strategy {strategy:?}"); debug!("Using strategy {reconnect_strategy:?}");
} }
match strategy.reconnect(&mut connection).await { match reconnect_strategy.reconnect(&mut connection).await {
Ok(x) => { Ok(()) => {
info!("Client successfully reconnected!");
needs_reconnect = false; needs_reconnect = false;
last_read_frame_time = Instant::now();
watcher_tx.send_replace(ConnectionState::Connected); watcher_tx.send_replace(ConnectionState::Connected);
x
} }
Err(x) => { Err(x) => {
error!("Unable to re-establish connection: {x}"); error!("Unable to re-establish connection: {x}");
@ -193,6 +202,29 @@ impl UntypedClient {
} }
} }
macro_rules! silence_needs_reconnect {
() => {{
debug!(
"Client exceeded {}s without server activity, so attempting to reconnect",
silence_duration.as_secs_f32(),
);
needs_reconnect = true;
watcher_tx.send_replace(ConnectionState::Reconnecting);
continue;
}};
}
let silence_time_remaining = silence_duration
.checked_sub(last_read_frame_time.elapsed())
.unwrap_or_default();
// NOTE: sleep will not trigger if duration is zero/nanosecond scale, so we
// instead will do an early check here in the case that we need to reconnect
// prior to a sleep while polling the ready status
if silence_time_remaining.as_millis() == 0 {
silence_needs_reconnect!();
}
let ready = tokio::select! { let ready = tokio::select! {
// NOTE: This should NEVER return None as we never allow the channel to close. // NOTE: This should NEVER return None as we never allow the channel to close.
cb = shutdown_rx.recv() => { cb = shutdown_rx.recv() => {
@ -202,6 +234,9 @@ impl UntypedClient {
watcher_tx.send_replace(ConnectionState::Disconnected); watcher_tx.send_replace(ConnectionState::Disconnected);
break Ok(()); break Ok(());
} }
_ = tokio::time::sleep(silence_time_remaining) => {
silence_needs_reconnect!();
}
result = connection.ready(Interest::READABLE | Interest::WRITABLE) => { result = connection.ready(Interest::READABLE | Interest::WRITABLE) => {
match result { match result {
Ok(result) => result, Ok(result) => result,
@ -220,7 +255,16 @@ impl UntypedClient {
if ready.is_readable() { if ready.is_readable() {
match connection.try_read_frame() { match connection.try_read_frame() {
// If we get an empty frame, we consider this a heartbeat and want
// to adjust our frame read time and discard it from our backup
Ok(Some(frame)) if frame.is_empty() => {
trace!("Client received heartbeat");
last_read_frame_time = Instant::now();
}
// Otherwise, we attempt to parse a frame as a response
Ok(Some(frame)) => { Ok(Some(frame)) => {
last_read_frame_time = Instant::now();
match UntypedResponse::from_slice(frame.as_item()) { match UntypedResponse::from_slice(frame.as_item()) {
Ok(response) => { Ok(response) => {
if log_enabled!(Level::Trace) { if log_enabled!(Level::Trace) {
@ -242,6 +286,7 @@ impl UntypedClient {
} }
} }
} }
Ok(None) => { Ok(None) => {
debug!("Connection closed"); debug!("Connection closed");
needs_reconnect = true; needs_reconnect = true;
@ -391,9 +436,9 @@ where
/// within a program. /// within a program.
pub fn spawn_inmemory( pub fn spawn_inmemory(
transport: FramedTransport<InmemoryTransport>, transport: FramedTransport<InmemoryTransport>,
strategy: ReconnectStrategy, config: ClientConfig,
) -> Self { ) -> Self {
UntypedClient::spawn_inmemory(transport, strategy).into_typed_client() UntypedClient::spawn_inmemory(transport, config).into_typed_client()
} }
} }
@ -515,6 +560,7 @@ impl<T, U> From<Client<T, U>> for Channel<T, U> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::client::ClientConfig;
use crate::common::{Ready, Request, Response, TestTransport}; use crate::common::{Ready, Request, Response, TestTransport};
mod typed { mod typed {
@ -524,12 +570,19 @@ mod tests {
fn spawn_test_client<T>( fn spawn_test_client<T>(
connection: Connection<T>, connection: Connection<T>,
strategy: ReconnectStrategy, reconnect_strategy: ReconnectStrategy,
) -> TestClient ) -> TestClient
where where
T: Transport + 'static, T: Transport + 'static,
{ {
UntypedClient::spawn(connection, strategy).into_typed_client() UntypedClient::spawn(
connection,
ClientConfig {
reconnect_strategy,
..Default::default()
},
)
.into_typed_client()
} }
/// Creates a new test transport whose operations do not panic, but do nothing. /// Creates a new test transport whose operations do not panic, but do nothing.
@ -848,7 +901,7 @@ mod tests {
async fn should_write_queued_requests_as_outgoing_frames() { async fn should_write_queued_requests_as_outgoing_frames() {
let (client, mut server) = Connection::pair(100); let (client, mut server) = Connection::pair(100);
let mut client = TestClient::spawn(client, ReconnectStrategy::Fail); let mut client = TestClient::spawn(client, Default::default());
client client
.fire(Request::new(1u8).to_untyped_request().unwrap()) .fire(Request::new(1u8).to_untyped_request().unwrap())
.await .await
@ -908,7 +961,7 @@ mod tests {
.unwrap(); .unwrap();
}); });
let mut client = TestClient::spawn(client, ReconnectStrategy::Fail); let mut client = TestClient::spawn(client, Default::default());
assert_eq!( assert_eq!(
client client
.send(Request::new(1u8).to_untyped_request().unwrap()) .send(Request::new(1u8).to_untyped_request().unwrap())
@ -938,10 +991,13 @@ mod tests {
transport transport
}), }),
ReconnectStrategy::FixedInterval { ClientConfig {
interval: Duration::from_millis(50), reconnect_strategy: ReconnectStrategy::FixedInterval {
max_retries: None, interval: Duration::from_millis(50),
timeout: None, max_retries: None,
timeout: None,
},
..Default::default()
}, },
); );
@ -969,10 +1025,13 @@ mod tests {
transport transport
}), }),
ReconnectStrategy::FixedInterval { ClientConfig {
interval: Duration::from_millis(50), reconnect_strategy: ReconnectStrategy::FixedInterval {
max_retries: None, interval: Duration::from_millis(50),
timeout: None, max_retries: None,
timeout: None,
},
..Default::default()
}, },
); );
@ -1000,10 +1059,13 @@ mod tests {
transport transport
}), }),
ReconnectStrategy::FixedInterval { ClientConfig {
interval: Duration::from_millis(50), reconnect_strategy: ReconnectStrategy::FixedInterval {
max_retries: None, interval: Duration::from_millis(50),
timeout: None, max_retries: None,
timeout: None,
},
..Default::default()
}, },
); );
@ -1031,10 +1093,13 @@ mod tests {
transport transport
}), }),
ReconnectStrategy::FixedInterval { ClientConfig {
interval: Duration::from_millis(50), reconnect_strategy: ReconnectStrategy::FixedInterval {
max_retries: None, interval: Duration::from_millis(50),
timeout: None, max_retries: None,
timeout: None,
},
..Default::default()
}, },
); );
@ -1079,10 +1144,13 @@ mod tests {
transport transport
}), }),
ReconnectStrategy::FixedInterval { ClientConfig {
interval: Duration::from_millis(50), reconnect_strategy: ReconnectStrategy::FixedInterval {
max_retries: None, interval: Duration::from_millis(50),
timeout: None, max_retries: None,
timeout: None,
},
..Default::default()
}, },
); );
@ -1101,7 +1169,7 @@ mod tests {
// Spawn the client, verify the task is running, kill our server, and verify that the // Spawn the client, verify the task is running, kill our server, and verify that the
// client does not block trying to reconnect // client does not block trying to reconnect
let client = TestClient::spawn(client, ReconnectStrategy::Fail); let client = TestClient::spawn(client, Default::default());
assert!(!client.is_finished(), "Client unexpectedly died"); assert!(!client.is_finished(), "Client unexpectedly died");
drop(server); drop(server);
assert_eq!( assert_eq!(
@ -1114,7 +1182,7 @@ mod tests {
async fn should_exit_if_shutdown_signal_detected() { async fn should_exit_if_shutdown_signal_detected() {
let (client, _server) = Connection::pair(100); let (client, _server) = Connection::pair(100);
let client = TestClient::spawn(client, ReconnectStrategy::Fail); let client = TestClient::spawn(client, Default::default());
client.shutdown().await.unwrap(); client.shutdown().await.unwrap();
// NOTE: We wait for the client's task to conclude by using `wait` to ensure we do not // NOTE: We wait for the client's task to conclude by using `wait` to ensure we do not
@ -1142,7 +1210,7 @@ mod tests {
// NOTE: We consume the client to produce a channel without maintaining the shutdown // NOTE: We consume the client to produce a channel without maintaining the shutdown
// channel in order to ensure that dropping the client does not kill the task. // channel in order to ensure that dropping the client does not kill the task.
let mut channel = TestClient::spawn(client, ReconnectStrategy::Fail).into_channel(); let mut channel = TestClient::spawn(client, Default::default()).into_channel();
assert_eq!( assert_eq!(
channel channel
.send(Request::new(1u8).to_untyped_request().unwrap()) .send(Request::new(1u8).to_untyped_request().unwrap())
@ -1154,5 +1222,30 @@ mod tests {
2 2
); );
} }
#[test(tokio::test)]
async fn should_attempt_to_reconnect_if_no_activity_from_server_within_silence_duration() {
let (client, _) = Connection::pair(100);
// NOTE: We consume the client to produce a channel without maintaining the shutdown
// channel in order to ensure that dropping the client does not kill the task.
let client = TestClient::spawn(
client,
ClientConfig {
silence_duration: Duration::from_millis(100),
reconnect_strategy: ReconnectStrategy::FixedInterval {
interval: Duration::from_millis(50),
max_retries: Some(3),
timeout: None,
},
},
);
let (tx, mut rx) = mpsc::unbounded_channel();
client.on_connection_change(move |state| tx.send(state).unwrap());
assert_eq!(rx.recv().await, Some(ConnectionState::Reconnecting));
assert_eq!(rx.recv().await, Some(ConnectionState::Disconnected));
assert_eq!(rx.recv().await, None);
}
} }
} }

@ -13,7 +13,8 @@ mod windows;
#[cfg(windows)] #[cfg(windows)]
pub use windows::*; pub use windows::*;
use crate::client::{Client, ReconnectStrategy, UntypedClient}; use super::ClientConfig;
use crate::client::{Client, UntypedClient};
use crate::common::{authentication::AuthHandler, Connection, Transport}; use crate::common::{authentication::AuthHandler, Connection, Transport};
use async_trait::async_trait; use async_trait::async_trait;
use std::{convert, io, time::Duration}; use std::{convert, io, time::Duration};
@ -40,44 +41,48 @@ impl<T: Transport + 'static> Connector for T {
pub struct ClientBuilder<H, C> { pub struct ClientBuilder<H, C> {
auth_handler: H, auth_handler: H,
connector: C, connector: C,
reconnect_strategy: ReconnectStrategy, config: ClientConfig,
timeout: Option<Duration>, connect_timeout: Option<Duration>,
} }
impl<H, C> ClientBuilder<H, C> { impl<H, C> ClientBuilder<H, C> {
/// Configure the authentication handler to use when connecting to a server.
pub fn auth_handler<U>(self, auth_handler: U) -> ClientBuilder<U, C> { pub fn auth_handler<U>(self, auth_handler: U) -> ClientBuilder<U, C> {
ClientBuilder { ClientBuilder {
auth_handler, auth_handler,
config: self.config,
connector: self.connector, connector: self.connector,
reconnect_strategy: self.reconnect_strategy, connect_timeout: self.connect_timeout,
timeout: self.timeout,
} }
} }
pub fn connector<U>(self, connector: U) -> ClientBuilder<H, U> { /// Configure the client-local configuration details.
ClientBuilder { pub fn config(self, config: ClientConfig) -> Self {
Self {
auth_handler: self.auth_handler, auth_handler: self.auth_handler,
connector, config,
reconnect_strategy: self.reconnect_strategy, connector: self.connector,
timeout: self.timeout, connect_timeout: self.connect_timeout,
} }
} }
pub fn reconnect_strategy(self, reconnect_strategy: ReconnectStrategy) -> ClientBuilder<H, C> { /// Configure the connector to use to facilitate connecting to a server.
pub fn connector<U>(self, connector: U) -> ClientBuilder<H, U> {
ClientBuilder { ClientBuilder {
auth_handler: self.auth_handler, auth_handler: self.auth_handler,
connector: self.connector, config: self.config,
reconnect_strategy, connector,
timeout: self.timeout, connect_timeout: self.connect_timeout,
} }
} }
pub fn timeout(self, timeout: impl Into<Option<Duration>>) -> Self { /// Configure a maximum duration to wait for a connection to a server to complete.
pub fn connect_timeout(self, connect_timeout: impl Into<Option<Duration>>) -> Self {
Self { Self {
auth_handler: self.auth_handler, auth_handler: self.auth_handler,
config: self.config,
connector: self.connector, connector: self.connector,
reconnect_strategy: self.reconnect_strategy, connect_timeout: connect_timeout.into(),
timeout: timeout.into(),
} }
} }
} }
@ -86,9 +91,9 @@ impl ClientBuilder<(), ()> {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
auth_handler: (), auth_handler: (),
reconnect_strategy: ReconnectStrategy::default(), config: Default::default(),
connector: (), connector: (),
timeout: None, connect_timeout: None,
} }
} }
} }
@ -109,11 +114,11 @@ where
/// is fully established and authenticated. /// is fully established and authenticated.
pub async fn connect_untyped(self) -> io::Result<UntypedClient> { pub async fn connect_untyped(self) -> io::Result<UntypedClient> {
let auth_handler = self.auth_handler; let auth_handler = self.auth_handler;
let retry_strategy = self.reconnect_strategy; let config = self.config;
let timeout = self.timeout; let connect_timeout = self.connect_timeout;
let f = async move { let f = async move {
let transport = match timeout { let transport = match connect_timeout {
Some(duration) => tokio::time::timeout(duration, self.connector.connect()) Some(duration) => tokio::time::timeout(duration, self.connector.connect())
.await .await
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x)) .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
@ -121,10 +126,10 @@ where
None => self.connector.connect().await?, None => self.connector.connect().await?,
}; };
let connection = Connection::client(transport, auth_handler).await?; let connection = Connection::client(transport, auth_handler).await?;
Ok(UntypedClient::spawn(connection, retry_strategy)) Ok(UntypedClient::spawn(connection, config))
}; };
match timeout { match connect_timeout {
Some(duration) => tokio::time::timeout(duration, f) Some(duration) => tokio::time::timeout(duration, f)
.await .await
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x)) .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))

@ -0,0 +1,34 @@
use super::ReconnectStrategy;
use std::time::Duration;
const DEFAULT_SILENCE_DURATION: Duration = Duration::from_secs(20);
const MAXIMUM_SILENCE_DURATION: Duration = Duration::from_millis(68719476734);
/// Represents a general-purpose set of properties tied with a client instance.
#[derive(Clone, Debug)]
pub struct ClientConfig {
/// Strategy to use when reconnecting to a server.
pub reconnect_strategy: ReconnectStrategy,
/// A maximum duration to not receive any response/heartbeat from a server before deeming the
/// server as lost and triggering a reconnect.
pub silence_duration: Duration,
}
impl ClientConfig {
pub fn with_maximum_silence_duration(self) -> Self {
Self {
reconnect_strategy: self.reconnect_strategy,
silence_duration: MAXIMUM_SILENCE_DURATION,
}
}
}
impl Default for ClientConfig {
fn default() -> Self {
Self {
reconnect_strategy: ReconnectStrategy::Fail,
silence_duration: DEFAULT_SILENCE_DURATION,
}
}
}

@ -1,4 +1,5 @@
use super::Reconnectable; use super::Reconnectable;
use log::*;
use std::io; use std::io;
use std::time::Duration; use std::time::Duration;
use strum::Display; use strum::Display;
@ -170,8 +171,11 @@ impl ReconnectStrategy {
}; };
// If reconnect was successful, we're done and we can exit // If reconnect was successful, we're done and we can exit
if result.is_ok() { match &result {
return Ok(()); Ok(()) => return Ok(()),
Err(x) => {
error!("Failed to reconnect: {x}");
}
} }
// Decrement remaining retries if we have a limit // Decrement remaining retries if we have a limit

@ -91,8 +91,8 @@ where
/// ///
/// For a client, this means performing an actual [`reconnect`] on the underlying /// For a client, this means performing an actual [`reconnect`] on the underlying
/// [`Transport`], re-establishing an encrypted codec, submitting a request to the server to /// [`Transport`], re-establishing an encrypted codec, submitting a request to the server to
/// reauthenticate using a previously-derived OTP, and refreshing the connection id and OTP for /// reauthenticate using a previously-derived OTP, and refreshing the OTP for use in a future
/// use in a future reauthentication. /// reauthentication.
/// ///
/// ### Server /// ### Server
/// ///
@ -101,10 +101,10 @@ where
/// [`reconnect`]: Reconnectable::reconnect /// [`reconnect`]: Reconnectable::reconnect
async fn reconnect(&mut self) -> io::Result<()> { async fn reconnect(&mut self) -> io::Result<()> {
async fn reconnect_client<T: Transport>( async fn reconnect_client<T: Transport>(
id: &mut ConnectionId, id: ConnectionId,
reauth_otp: &mut HeapSecretKey, reauth_otp: HeapSecretKey,
transport: &mut FramedTransport<T>, transport: &mut FramedTransport<T>,
) -> io::Result<()> { ) -> io::Result<(ConnectionId, HeapSecretKey)> {
// Re-establish a raw connection // Re-establish a raw connection
debug!("[Conn {id}] Re-establishing connection"); debug!("[Conn {id}] Re-establishing connection");
Reconnectable::reconnect(transport).await?; Reconnectable::reconnect(transport).await?;
@ -117,29 +117,16 @@ where
debug!("[Conn {id}] Performing re-authentication"); debug!("[Conn {id}] Performing re-authentication");
transport transport
.write_frame_for(&ConnectType::Reconnect { .write_frame_for(&ConnectType::Reconnect {
id: *id, id,
otp: reauth_otp.unprotected_as_bytes().to_vec(), otp: reauth_otp.unprotected_into_bytes(),
}) })
.await?; .await?;
// Receive the new id for the connection
// NOTE: If we fail re-authentication above,
// this will fail as the connection is dropped
debug!("[Conn {id}] Receiving new connection id");
let new_id = transport
.read_frame_as::<ConnectionId>()
.await?
.ok_or_else(|| {
io::Error::new(io::ErrorKind::Other, "Missing connection id frame")
})?;
debug!("[Conn {id}] Resetting id to {new_id}");
*id = new_id;
// Derive an OTP for reauthentication // Derive an OTP for reauthentication
debug!("[Conn {id}] Deriving future OTP for reauthentication"); debug!("[Conn {id}] Deriving future OTP for reauthentication");
*reauth_otp = transport.exchange_keys().await?.into_heap_secret_key(); let reauth_otp = transport.exchange_keys().await?.into_heap_secret_key();
Ok(()) Ok((id, reauth_otp))
} }
match self { match self {
@ -148,19 +135,24 @@ where
transport, transport,
reauth_otp, reauth_otp,
} => { } => {
// Freeze our backup as we don't want the connection logic to alter it // Freeze our backup as we don't want the connection logic to alter it, attempt to
transport.backup.freeze(); // perform the reconnection, and unfreeze our backup regardless of the result
let (new_id, new_reauth_otp) = {
// Attempt to perform the reconnection and unfreeze our backup regardless of the transport.backup.freeze();
// result let result = reconnect_client(*id, reauth_otp.clone(), transport).await;
let result = reconnect_client(id, reauth_otp, transport).await; transport.backup.unfreeze();
transport.backup.unfreeze(); result?
result?; };
// Perform synchronization // Perform synchronization
debug!("[Conn {id}] Synchronizing frame state"); debug!("[Conn {id}] Synchronizing frame state");
transport.synchronize().await?; transport.synchronize().await?;
// Everything has succeeded, so we now will update our id and reauth otp
info!("[Conn {id}] Reconnect completed successfully! Assigning new id {new_id}");
*id = new_id;
*reauth_otp = new_reauth_otp;
Ok(()) Ok(())
} }
@ -234,6 +226,7 @@ where
debug!("[Conn {id}] Deriving future OTP for reauthentication"); debug!("[Conn {id}] Deriving future OTP for reauthentication");
let reauth_otp = transport.exchange_keys().await?.into_heap_secret_key(); let reauth_otp = transport.exchange_keys().await?.into_heap_secret_key();
info!("[Conn {id}] Connect completed successfully!");
Ok(Self::Client { Ok(Self::Client {
id, id,
reauth_otp, reauth_otp,
@ -283,7 +276,7 @@ where
// Based on the connection type, we either try to find and validate an existing connection // Based on the connection type, we either try to find and validate an existing connection
// or we perform normal verification // or we perform normal verification
match connection_type { let id = match connection_type {
ConnectType::Connect => { ConnectType::Connect => {
// Communicate the connection id // Communicate the connection id
debug!("[Conn {id}] Telling other side to change connection id"); debug!("[Conn {id}] Telling other side to change connection id");
@ -298,42 +291,69 @@ where
let reauth_otp = transport.exchange_keys().await?.into_heap_secret_key(); let reauth_otp = transport.exchange_keys().await?.into_heap_secret_key();
// Store the id, OTP, and backup retrieval in our database // Store the id, OTP, and backup retrieval in our database
info!("[Conn {id}] Connect completed successfully!");
keychain.insert(id.to_string(), reauth_otp, rx).await; keychain.insert(id.to_string(), reauth_otp, rx).await;
id
} }
ConnectType::Reconnect { id: other_id, otp } => { ConnectType::Reconnect { id: other_id, otp } => {
let reauth_otp = HeapSecretKey::from(otp); let reauth_otp = HeapSecretKey::from(otp);
debug!("[Conn {id}] Checking if {other_id} exists and has matching OTP"); debug!("[Conn {id}] Checking if {other_id} exists and has matching OTP");
match keychain match keychain
.remove_if_has_key(other_id.to_string(), reauth_otp) .remove_if_has_key(other_id.to_string(), reauth_otp.clone())
.await .await
{ {
KeychainResult::Ok(x) => { KeychainResult::Ok(x) => {
// Communicate the connection id // Match found, so we want ot update our id to be the pre-existing id
debug!("[Conn {id}] Telling other side to change connection id"); debug!("[Conn {id}] Reassigning to {other_id}");
transport.write_frame_for(&id).await?; let id = other_id;
// Derive an OTP for reauthentication
debug!("[Conn {id}] Deriving future OTP for reauthentication");
let reauth_otp = transport.exchange_keys().await?.into_heap_secret_key();
// Grab the old backup and swap it into our transport // Grab the old backup
debug!("[Conn {id}] Acquiring backup for existing connection"); debug!("[Conn {id}] Acquiring backup for existing connection");
match x.await { let backup = match x.await {
Ok(backup) => { Ok(backup) => backup,
transport.backup = backup;
}
Err(_) => { Err(_) => {
warn!("[Conn {id}] Missing backup"); warn!("[Conn {id}] Missing backup, will use fresh copy");
Backup::new()
} }
};
macro_rules! unwrap_or_fail {
($action:expr) => {
unwrap_or_fail!(backup, $action)
};
($backup:expr, $action:expr) => {{
match $action {
Ok(x) => x,
Err(x) => {
error!("[Conn {id}] Encountered error, restoring with old backup");
let _ = tx.send($backup);
keychain.insert(id.to_string(), reauth_otp, rx).await;
return Err(x);
}
}
}};
} }
// Derive an OTP for reauthentication
debug!("[Conn {id}] Deriving future OTP for reauthentication");
let new_reauth_otp =
unwrap_or_fail!(transport.exchange_keys().await).into_heap_secret_key();
// Replace our backup with the old one
debug!("[Conn {id}] Restoring backup");
transport.backup = backup;
// Synchronize using the provided backup // Synchronize using the provided backup
debug!("[Conn {id}] Synchronizing frame state"); debug!("[Conn {id}] Synchronizing frame state");
transport.synchronize().await?; unwrap_or_fail!(transport.backup, transport.synchronize().await);
// Store the id, OTP, and backup retrieval in our database // Store the id, OTP, and backup retrieval in our database
keychain.insert(id.to_string(), reauth_otp, rx).await; info!("[Conn {id}] Reconnect restoration completed successfully!");
keychain.insert(id.to_string(), new_reauth_otp, rx).await;
id
} }
KeychainResult::InvalidPassword => { KeychainResult::InvalidPassword => {
return Err(io::Error::new( return Err(io::Error::new(
@ -349,7 +369,7 @@ where
} }
} }
} }
} };
Ok(Self::Server { id, tx, transport }) Ok(Self::Server { id, tx, transport })
} }
@ -384,7 +404,6 @@ impl Connection<InmemoryTransport> {
} }
} }
#[cfg(test)]
impl<T> Connection<T> { impl<T> Connection<T> {
/// Returns the id of the connection. /// Returns the id of the connection.
pub fn id(&self) -> ConnectionId { pub fn id(&self) -> ConnectionId {
@ -393,7 +412,10 @@ impl<T> Connection<T> {
Self::Server { id, .. } => *id, Self::Server { id, .. } => *id,
} }
} }
}
#[cfg(test)]
impl<T> Connection<T> {
/// Returns the OTP associated with the connection, or none if connection is server-side. /// Returns the OTP associated with the connection, or none if connection is server-side.
pub fn otp(&self) -> Option<&HeapSecretKey> { pub fn otp(&self) -> Option<&HeapSecretKey> {
match self { match self {
@ -821,9 +843,6 @@ mod tests {
.await .await
.unwrap(); .unwrap();
// Receive a new client id
let _id = t1.read_frame_as::<ConnectionId>().await.unwrap().unwrap();
// Send garbage to fail the otp exchange // Send garbage to fail the otp exchange
t1.write_frame(Frame::new(b"hello")).await.unwrap(); t1.write_frame(Frame::new(b"hello")).await.unwrap();
@ -862,9 +881,6 @@ mod tests {
.await .await
.unwrap(); .unwrap();
// Receive a new client id
let _id = t1.read_frame_as::<ConnectionId>().await.unwrap().unwrap();
// Perform otp exchange // Perform otp exchange
let _otp = t1.exchange_keys().await.unwrap(); let _otp = t1.exchange_keys().await.unwrap();
@ -928,9 +944,10 @@ mod tests {
let verifier = Verifier::none(); let verifier = Verifier::none();
let keychain = Keychain::new(); let keychain = Keychain::new();
let key = HeapSecretKey::generate(32).unwrap(); let key = HeapSecretKey::generate(32).unwrap();
let id = 1234;
keychain keychain
.insert(1234.to_string(), key.clone(), { .insert(id.to_string(), key.clone(), {
// Create a custom backup we'll use to replay frames from the server-side // Create a custom backup we'll use to replay frames from the server-side
let mut backup = Backup::new(); let mut backup = Backup::new();
@ -968,9 +985,6 @@ mod tests {
.await .await
.unwrap(); .unwrap();
// Receive a new client id
let id = t1.read_frame_as::<ConnectionId>().await.unwrap().unwrap();
// Perform otp exchange // Perform otp exchange
let otp = t1.exchange_keys().await.unwrap(); let otp = t1.exchange_keys().await.unwrap();
@ -996,9 +1010,6 @@ mod tests {
// Validate the connection ids match // Validate the connection ids match
assert_eq!(server.id(), id); assert_eq!(server.id(), id);
// Check that our old connection id is no longer contained in the keychain
assert!(!keychain.has_id("1234").await, "Old OTP still exists");
// Validate the OTP was stored in our keychain // Validate the OTP was stored in our keychain
assert!( assert!(
keychain keychain
@ -1210,12 +1221,6 @@ mod tests {
.await .await
.expect("Failed to retrieve backup"); .expect("Failed to retrieve backup");
// Send a new id back to the client connection
transport
.write_frame_for(&rand::random::<ConnectionId>())
.await
.unwrap();
// Perform key exchange // Perform key exchange
let otp = transport.exchange_keys().await.unwrap(); let otp = transport.exchange_keys().await.unwrap();

@ -1,5 +1,5 @@
use async_trait::async_trait; use async_trait::async_trait;
use std::{io, time::Duration}; use std::{fmt, io, time::Duration};
mod framed; mod framed;
pub use framed::*; pub use framed::*;
@ -42,7 +42,7 @@ pub trait Reconnectable {
/// Interface representing a transport of raw bytes into and out of the system. /// Interface representing a transport of raw bytes into and out of the system.
#[async_trait] #[async_trait]
pub trait Transport: Reconnectable + Send + Sync { pub trait Transport: Reconnectable + fmt::Debug + Send + Sync {
/// Tries to read data from the transport into the provided buffer, returning how many bytes /// Tries to read data from the transport into the provided buffer, returning how many bytes
/// were read. /// were read.
/// ///

@ -254,12 +254,13 @@ impl<T: Transport> FramedTransport<T> {
macro_rules! read_next_frame { macro_rules! read_next_frame {
() => {{ () => {{
match Frame::read(&mut self.incoming) { match Frame::read(&mut self.incoming) {
Ok(None) => (), None => (),
Ok(Some(frame)) => { Some(frame) => {
self.backup.increment_received_cnt(); if frame.is_nonempty() {
self.backup.increment_received_cnt();
}
return Ok(Some(self.codec.decode(frame)?.into_owned())); return Ok(Some(self.codec.decode(frame)?.into_owned()));
} }
Err(x) => return Err(x),
} }
}}; }};
} }
@ -363,14 +364,17 @@ impl<T: Transport> FramedTransport<T> {
// Encode the frame and store it in our outgoing queue // Encode the frame and store it in our outgoing queue
self.codec self.codec
.encode(frame.as_borrowed())? .encode(frame.as_borrowed())?
.write(&mut self.outgoing)?; .write(&mut self.outgoing);
// Once the frame enters our queue, we count it as written, even if it isn't fully flushed // Update tracking stats and more of backup if frame is nonempty
self.backup.increment_sent_cnt(); if frame.is_nonempty() {
// Once the frame enters our queue, we count it as written, even if it isn't fully flushed
self.backup.increment_sent_cnt();
// Then we store the raw frame (non-encoded) for the future in case we need to retry // Then we store the raw frame (non-encoded) for the future in case we need to retry
// sending it later (possibly with a different codec) // sending it later (possibly with a different codec)
self.backup.push_frame(frame); self.backup.push_frame(frame);
}
// Attempt to write everything in our queue // Attempt to write everything in our queue
self.try_flush()?; self.try_flush()?;
@ -535,7 +539,7 @@ impl<T: Transport> FramedTransport<T> {
// Encode our frame and write it to be queued in our incoming data // Encode our frame and write it to be queued in our incoming data
// NOTE: We have to do encoding here as incoming bytes are expected to be encoded // NOTE: We have to do encoding here as incoming bytes are expected to be encoded
this.codec.encode(frame)?.write(&mut this.incoming)?; this.codec.encode(frame)?.write(&mut this.incoming);
} }
// Catch up our read count as we can have the case where the other side has a higher // Catch up our read count as we can have the case where the other side has a higher
@ -701,7 +705,7 @@ impl<T: Transport> FramedTransport<T> {
})?; })?;
// Choose a compression and encryption option from the options // Choose a compression and encryption option from the options
debug!("[{log_label}] Selecting from options: {options:#?}"); debug!("[{log_label}] Selecting from options: {options:?}");
let choice = Choice { let choice = Choice {
// Use preferred compression if available, otherwise default to no compression // Use preferred compression if available, otherwise default to no compression
// to avoid choosing something poor // to avoid choosing something poor
@ -725,7 +729,7 @@ impl<T: Transport> FramedTransport<T> {
}; };
// Report back to the server the choice // Report back to the server the choice
debug!("[{log_label}] Reporting choice: {choice:#?}"); debug!("[{log_label}] Reporting choice: {choice:?}");
self.write_frame_for(&choice).await?; self.write_frame_for(&choice).await?;
choice choice
@ -740,7 +744,7 @@ impl<T: Transport> FramedTransport<T> {
}; };
// Send options to the client // Send options to the client
debug!("[{log_label}] Sending options: {options:#?}"); debug!("[{log_label}] Sending options: {options:?}");
self.write_frame_for(&options).await?; self.write_frame_for(&options).await?;
// Get client's response with selected compression and encryption // Get client's response with selected compression and encryption
@ -754,7 +758,7 @@ impl<T: Transport> FramedTransport<T> {
} }
}; };
debug!("[{log_label}] Building compression & encryption codecs based on {choice:#?}"); debug!("[{log_label}] Building compression & encryption codecs based on {choice:?}");
let compression_level = choice.compression_level.unwrap_or_default(); let compression_level = choice.compression_level.unwrap_or_default();
// Acquire a codec for the compression type // Acquire a codec for the compression type
@ -968,7 +972,7 @@ mod tests {
let mut buf = BytesMut::new(); let mut buf = BytesMut::new();
for frame in frames { for frame in frames {
frame.write(&mut buf).unwrap(); frame.write(&mut buf);
} }
buf.to_vec() buf.to_vec()
@ -1059,7 +1063,7 @@ mod tests {
fn try_read_frame_should_return_next_available_frame() { fn try_read_frame_should_return_next_available_frame() {
let data = { let data = {
let mut data = BytesMut::new(); let mut data = BytesMut::new();
Frame::new(b"hello world").write(&mut data).unwrap(); Frame::new(b"hello world").write(&mut data);
data.freeze() data.freeze()
}; };
@ -1082,8 +1086,8 @@ mod tests {
// Store two frames in our data to transmit // Store two frames in our data to transmit
let data = { let data = {
let mut data = BytesMut::new(); let mut data = BytesMut::new();
Frame::new(b"hello world").write(&mut data).unwrap(); Frame::new(b"hello world").write(&mut data);
Frame::new(b"hello again").write(&mut data).unwrap(); Frame::new(b"hello again").write(&mut data);
data.freeze() data.freeze()
}; };
@ -1746,8 +1750,8 @@ mod tests {
let (mut t1, mut t2) = FramedTransport::pair(100); let (mut t1, mut t2) = FramedTransport::pair(100);
// Put some frames into the incoming and outgoing of our transport // Put some frames into the incoming and outgoing of our transport
Frame::new(b"bad incoming").write(&mut t2.incoming).unwrap(); Frame::new(b"bad incoming").write(&mut t2.incoming);
Frame::new(b"bad outgoing").write(&mut t2.outgoing).unwrap(); Frame::new(b"bad outgoing").write(&mut t2.outgoing);
// Configure the backup such that we have sent two frames // Configure the backup such that we have sent two frames
t2.backup.push_frame(Frame::new(b"hello")); t2.backup.push_frame(Frame::new(b"hello"));

@ -5,6 +5,11 @@ use std::collections::VecDeque;
const MAX_BACKUP_SIZE: usize = 256 * 1024 * 1024; const MAX_BACKUP_SIZE: usize = 256 * 1024 * 1024;
/// Stores [`Frame`]s for reuse later. /// Stores [`Frame`]s for reuse later.
///
/// ### Note
///
/// Empty [`Frame`]s are an exception and are not stored within the backup nor
/// are they tracked in terms of sent/received counts.
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
pub struct Backup { pub struct Backup {
/// Maximum size (in bytes) to save frames in case we need to backup them /// Maximum size (in bytes) to save frames in case we need to backup them

@ -1,5 +1,5 @@
use bytes::{Buf, BufMut, BytesMut}; use bytes::{Buf, BufMut, BytesMut};
use std::{borrow::Cow, io}; use std::borrow::Cow;
/// Represents a frame whose lifetime is static /// Represents a frame whose lifetime is static
pub type OwnedFrame = Frame<'static>; pub type OwnedFrame = Frame<'static>;
@ -13,7 +13,7 @@ pub struct Frame<'a> {
} }
impl<'a> Frame<'a> { impl<'a> Frame<'a> {
/// Creates a new frame wrapping the `item` that will be shipped across the network /// Creates a new frame wrapping the `item` that will be shipped across the network.
pub fn new(item: &'a [u8]) -> Self { pub fn new(item: &'a [u8]) -> Self {
Self { Self {
item: Cow::Borrowed(item), item: Cow::Borrowed(item),
@ -27,75 +27,66 @@ impl<'a> Frame<'a> {
} }
impl Frame<'_> { impl Frame<'_> {
/// Total bytes to use as the header field denoting a frame's size /// Total bytes to use as the header field denoting a frame's size.
pub const HEADER_SIZE: usize = 8; pub const HEADER_SIZE: usize = 8;
/// Returns the len (in bytes) of the item wrapped by the frame /// Creates a new frame with no item.
pub fn empty() -> Self {
Self::new(&[])
}
/// Returns the len (in bytes) of the item wrapped by the frame.
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
self.item.len() self.item.len()
} }
/// Returns true if the frame is comprised of zero bytes /// Returns true if the frame is comprised of zero bytes.
pub fn is_empty(&self) -> bool { pub fn is_empty(&self) -> bool {
self.item.is_empty() self.item.is_empty()
} }
/// Returns a reference to the bytes of the frame's item /// Returns true if the frame is comprised of some bytes.
#[inline]
pub fn is_nonempty(&self) -> bool {
!self.is_empty()
}
/// Returns a reference to the bytes of the frame's item.
pub fn as_item(&self) -> &[u8] { pub fn as_item(&self) -> &[u8] {
&self.item &self.item
} }
/// Writes the frame to a new [`Vec`] of bytes, returning them on success /// Writes the frame to a new [`Vec`] of bytes, returning them on success.
pub fn try_to_bytes(&self) -> io::Result<Vec<u8>> { pub fn to_bytes(&self) -> Vec<u8> {
let mut bytes = BytesMut::new(); let mut bytes = BytesMut::new();
self.write(&mut bytes)?; self.write(&mut bytes);
Ok(bytes.to_vec()) bytes.to_vec()
} }
/// Writes the frame to the end of `dst`, including the header representing the length of the /// Writes the frame to the end of `dst`, including the header representing the length of the
/// item as part of the written bytes /// item as part of the written bytes.
pub fn write(&self, dst: &mut BytesMut) -> io::Result<()> { pub fn write(&self, dst: &mut BytesMut) {
if self.item.is_empty() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Empty item provided",
));
}
dst.reserve(Self::HEADER_SIZE + self.item.len()); dst.reserve(Self::HEADER_SIZE + self.item.len());
// Add data in form of {LEN}{ITEM} // Add data in form of {LEN}{ITEM}
dst.put_u64((self.item.len()) as u64); dst.put_u64((self.item.len()) as u64);
dst.put_slice(&self.item); dst.put_slice(&self.item);
Ok(())
} }
/// Attempts to read a frame from `src`, returning `Some(Frame)` if a frame was found /// Attempts to read a frame from `src`, returning `Some(Frame)` if a frame was found
/// (including the header) or `None` if the current `src` does not contain a frame /// (including the header) or `None` if the current `src` does not contain a frame.
pub fn read(src: &mut BytesMut) -> io::Result<Option<OwnedFrame>> { pub fn read(src: &mut BytesMut) -> Option<OwnedFrame> {
// First, check if we have more data than just our frame's message length // First, check if we have more data than just our frame's message length
if src.len() <= Self::HEADER_SIZE { if src.len() <= Self::HEADER_SIZE {
return Ok(None); return None;
} }
// Second, retrieve total size of our frame's message // Second, retrieve total size of our frame's message
let item_len = u64::from_be_bytes(src[..Self::HEADER_SIZE].try_into().unwrap()) as usize; let item_len = u64::from_be_bytes(src[..Self::HEADER_SIZE].try_into().unwrap()) as usize;
// In the case that our item len is 0, we skip over the invalid frame
if item_len == 0 {
// Ensure we advance to remove the frame
src.advance(Self::HEADER_SIZE);
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Frame's msg cannot have length of 0",
));
}
// Third, check if we have all data for our frame; if not, exit early // Third, check if we have all data for our frame; if not, exit early
if src.len() < item_len + Self::HEADER_SIZE { if src.len() < item_len + Self::HEADER_SIZE {
return Ok(None); return None;
} }
// Fourth, get and return our item // Fourth, get and return our item
@ -104,13 +95,13 @@ impl Frame<'_> {
// Fifth, advance so frame is no longer kept around // Fifth, advance so frame is no longer kept around
src.advance(Self::HEADER_SIZE + item_len); src.advance(Self::HEADER_SIZE + item_len);
Ok(Some(Frame::from(item))) Some(Frame::from(item))
} }
/// Checks if a full frame is available from `src`, returning true if a frame was found false /// Checks if a full frame is available from `src`, returning true if a frame was found false
/// if the current `src` does not contain a frame. Does not consume the frame. /// if the current `src` does not contain a frame. Does not consume the frame.
pub fn available(src: &BytesMut) -> bool { pub fn available(src: &BytesMut) -> bool {
matches!(Frame::read(&mut src.clone()), Ok(Some(_))) matches!(Frame::read(&mut src.clone()), Some(_))
} }
/// Returns a new frame which is identical but has a lifetime tied to this frame. /// Returns a new frame which is identical but has a lifetime tied to this frame.
@ -239,16 +230,15 @@ mod tests {
use test_log::test; use test_log::test;
#[test] #[test]
fn write_should_fail_when_item_is_zero_bytes() { fn write_should_succeed_when_item_is_zero_bytes() {
let frame = Frame::new(&[]); let frame = Frame::new(&[]);
let mut buf = BytesMut::new(); let mut buf = BytesMut::new();
let result = frame.write(&mut buf); frame.write(&mut buf);
match result { // Writing a frame of zero bytes means that the header is all zeros and there is
Err(x) if x.kind() == io::ErrorKind::InvalidInput => {} // no item that follows the header
x => panic!("Unexpected result: {:?}", x), assert_eq!(buf.as_ref(), &[0, 0, 0, 0, 0, 0, 0, 0]);
}
} }
#[test] #[test]
@ -256,7 +246,7 @@ mod tests {
let frame = Frame::new(b"hello, world"); let frame = Frame::new(b"hello, world");
let mut buf = BytesMut::new(); let mut buf = BytesMut::new();
frame.write(&mut buf).expect("Failed to write"); frame.write(&mut buf);
let len = buf.get_u64() as usize; let len = buf.get_u64() as usize;
assert_eq!(len, 12, "Wrong length writed"); assert_eq!(len, 12, "Wrong length writed");
@ -269,11 +259,7 @@ mod tests {
buf.put_bytes(0, Frame::HEADER_SIZE); buf.put_bytes(0, Frame::HEADER_SIZE);
let result = Frame::read(&mut buf); let result = Frame::read(&mut buf);
assert!( assert!(matches!(result, None), "Unexpected result: {:?}", result);
matches!(result, Ok(None)),
"Unexpected result: {:?}",
result
);
} }
#[test] #[test]
@ -282,24 +268,21 @@ mod tests {
buf.put_u64(0); buf.put_u64(0);
let result = Frame::read(&mut buf); let result = Frame::read(&mut buf);
assert!( assert!(matches!(result, None), "Unexpected result: {:?}", result);
matches!(result, Ok(None)),
"Unexpected result: {:?}",
result
);
} }
#[test] #[test]
fn read_should_fail_if_writed_item_length_is_zero() { fn read_should_succeed_if_written_item_length_is_zero() {
let mut buf = BytesMut::new(); let mut buf = BytesMut::new();
buf.put_u64(0); buf.put_u64(0);
buf.put_u8(255); buf.put_u8(255);
let result = Frame::read(&mut buf); // Reading will result in a frame of zero bytes
match result { let frame = Frame::read(&mut buf).expect("missing frame");
Err(x) if x.kind() == io::ErrorKind::InvalidData => {} assert_eq!(frame, Frame::empty());
x => panic!("Unexpected result: {:?}", x),
} // Nothing following the frame header should have been extracted
assert_eq!(buf.as_ref(), &[255]);
} }
#[test] #[test]
@ -308,10 +291,7 @@ mod tests {
buf.put_u64(0); buf.put_u64(0);
buf.put_bytes(0, 3); buf.put_bytes(0, 3);
assert!( assert_eq!(Frame::read(&mut buf).unwrap(), Frame::empty());
Frame::read(&mut buf).is_err(),
"read unexpectedly succeeded"
);
assert_eq!(buf.len(), 3, "Advanced an unexpected amount in src buf"); assert_eq!(buf.len(), 3, "Advanced an unexpected amount in src buf");
} }
@ -319,25 +299,22 @@ mod tests {
fn read_should_advance_src_by_frame_size_when_successful() { fn read_should_advance_src_by_frame_size_when_successful() {
// Add 3 extra bytes after a full frame // Add 3 extra bytes after a full frame
let mut buf = BytesMut::new(); let mut buf = BytesMut::new();
Frame::new(b"hello, world") Frame::new(b"hello, world").write(&mut buf);
.write(&mut buf)
.expect("Failed to write");
buf.put_bytes(0, 3); buf.put_bytes(0, 3);
assert!(Frame::read(&mut buf).is_ok(), "read unexpectedly failed"); assert!(
Frame::read(&mut buf).is_some(),
"read unexpectedly missing frame"
);
assert_eq!(buf.len(), 3, "Advanced an unexpected amount in src buf"); assert_eq!(buf.len(), 3, "Advanced an unexpected amount in src buf");
} }
#[test] #[test]
fn read_should_return_some_byte_vec_when_successful() { fn read_should_return_some_byte_vec_when_successful() {
let mut buf = BytesMut::new(); let mut buf = BytesMut::new();
Frame::new(b"hello, world") Frame::new(b"hello, world").write(&mut buf);
.write(&mut buf)
.expect("Failed to write");
let item = Frame::read(&mut buf) let item = Frame::read(&mut buf).expect("missing frame");
.expect("Failed to read")
.expect("Item not properly captured");
assert_eq!(item, b"hello, world"); assert_eq!(item, b"hello, world");
} }
} }

@ -1,6 +1,6 @@
use super::{Interest, Ready, Reconnectable, Transport}; use super::{Interest, Ready, Reconnectable, Transport};
use async_trait::async_trait; use async_trait::async_trait;
use std::io; use std::{fmt, io};
pub type TryReadFn = Box<dyn Fn(&mut [u8]) -> io::Result<usize> + Send + Sync>; pub type TryReadFn = Box<dyn Fn(&mut [u8]) -> io::Result<usize> + Send + Sync>;
pub type TryWriteFn = Box<dyn Fn(&[u8]) -> io::Result<usize> + Send + Sync>; pub type TryWriteFn = Box<dyn Fn(&[u8]) -> io::Result<usize> + Send + Sync>;
@ -25,6 +25,12 @@ impl Default for TestTransport {
} }
} }
impl fmt::Debug for TestTransport {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TestTransport").finish()
}
}
#[async_trait] #[async_trait]
impl Reconnectable for TestTransport { impl Reconnectable for TestTransport {
async fn reconnect(&mut self) -> io::Result<()> { async fn reconnect(&mut self) -> io::Result<()> {

@ -303,13 +303,13 @@ impl ManagerClient {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::client::{ReconnectStrategy, UntypedClient}; use crate::client::UntypedClient;
use crate::common::authentication::DummyAuthHandler; use crate::common::authentication::DummyAuthHandler;
use crate::common::{Connection, InmemoryTransport, Request, Response}; use crate::common::{Connection, InmemoryTransport, Request, Response};
fn setup() -> (ManagerClient, Connection<InmemoryTransport>) { fn setup() -> (ManagerClient, Connection<InmemoryTransport>) {
let (client, server) = Connection::pair(100); let (client, server) = Connection::pair(100);
let client = UntypedClient::spawn(client, ReconnectStrategy::Fail).into_typed_client(); let client = UntypedClient::spawn(client, Default::default()).into_typed_client();
(client, server) (client, server)
} }

@ -1,5 +1,5 @@
use crate::{ use crate::{
client::{Client, ReconnectStrategy, UntypedClient}, client::{Client, ClientConfig, UntypedClient},
common::{ConnectionId, FramedTransport, InmemoryTransport, UntypedRequest}, common::{ConnectionId, FramedTransport, InmemoryTransport, UntypedRequest},
manager::data::{ManagerRequest, ManagerResponse}, manager::data::{ManagerRequest, ManagerResponse},
}; };
@ -35,7 +35,10 @@ impl RawChannel {
T: Send + Sync + Serialize + 'static, T: Send + Sync + Serialize + 'static,
U: Send + Sync + DeserializeOwned + 'static, U: Send + Sync + DeserializeOwned + 'static,
{ {
Client::spawn_inmemory(self.transport, ReconnectStrategy::Fail) Client::spawn_inmemory(
self.transport,
ClientConfig::default().with_maximum_silence_duration(),
)
} }
/// Consumes this channel, returning an untyped client wrapping the transport. /// Consumes this channel, returning an untyped client wrapping the transport.
@ -46,7 +49,10 @@ impl RawChannel {
/// performed during separate connection and this merely wraps an inmemory transport that maps /// performed during separate connection and this merely wraps an inmemory transport that maps
/// to the primary connection. /// to the primary connection.
pub fn into_untyped_client(self) -> UntypedClient { pub fn into_untyped_client(self) -> UntypedClient {
UntypedClient::spawn_inmemory(self.transport, ReconnectStrategy::Fail) UntypedClient::spawn_inmemory(
self.transport,
ClientConfig::default().with_maximum_silence_duration(),
)
} }
/// Returns reference to the underlying framed transport. /// Returns reference to the underlying framed transport.

@ -316,7 +316,7 @@ impl ServerHandler for ManagerServer {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::client::{ReconnectStrategy, UntypedClient}; use crate::client::UntypedClient;
use crate::common::FramedTransport; use crate::common::FramedTransport;
use crate::server::ServerReply; use crate::server::ServerReply;
use crate::{boxed_connect_handler, boxed_launch_handler}; use crate::{boxed_connect_handler, boxed_launch_handler};
@ -335,7 +335,7 @@ mod tests {
/// Create an untyped client that is detached such that reads and writes will fail /// Create an untyped client that is detached such that reads and writes will fail
fn detached_untyped_client() -> UntypedClient { fn detached_untyped_client() -> UntypedClient {
UntypedClient::spawn_inmemory(FramedTransport::pair(1).0, ReconnectStrategy::Fail) UntypedClient::spawn_inmemory(FramedTransport::pair(1).0, Default::default())
} }
/// Create a new server and authenticator /// Create a new server and authenticator

@ -1,9 +1,9 @@
use crate::common::{authentication::Verifier, Listener, Transport}; use crate::common::{authentication::Verifier, Listener, Response, Transport};
use async_trait::async_trait; use async_trait::async_trait;
use log::*; use log::*;
use serde::{de::DeserializeOwned, Serialize}; use serde::{de::DeserializeOwned, Serialize};
use std::{io, sync::Arc, time::Duration}; use std::{io, sync::Arc, time::Duration};
use tokio::sync::RwLock; use tokio::sync::{broadcast, RwLock};
mod builder; mod builder;
pub use builder::*; pub use builder::*;
@ -148,14 +148,20 @@ where
L::Output: Transport + 'static, L::Output: Transport + 'static,
{ {
let state = Arc::new(ServerState::new()); let state = Arc::new(ServerState::new());
let task = tokio::spawn(self.task(Arc::clone(&state), listener)); let (tx, rx) = broadcast::channel(1);
let task = tokio::spawn(self.task(Arc::clone(&state), listener, tx.clone(), rx));
Ok(Box::new(GenericServerRef { state, task })) Ok(Box::new(GenericServerRef { shutdown: tx, task }))
} }
/// Internal task that is run to receive connections and spawn connection tasks /// Internal task that is run to receive connections and spawn connection tasks
async fn task<L>(self, state: Arc<ServerState>, mut listener: L) async fn task<L>(
where self,
state: Arc<ServerState<Response<T::Response>>>,
mut listener: L,
shutdown_tx: broadcast::Sender<()>,
shutdown_rx: broadcast::Receiver<()>,
) where
L: Listener + 'static, L: Listener + 'static,
L::Output: Transport + 'static, L::Output: Transport + 'static,
{ {
@ -171,6 +177,7 @@ where
let timer = Arc::new(RwLock::new(timer)); let timer = Arc::new(RwLock::new(timer));
let verifier = Arc::new(verifier); let verifier = Arc::new(verifier);
let mut connection_tasks = Vec::new();
loop { loop {
// Receive a new connection, exiting if no longer accepting connections or if the shutdown // Receive a new connection, exiting if no longer accepting connections or if the shutdown
// signal has been received // signal has been received
@ -191,10 +198,7 @@ where
config.shutdown.duration().unwrap_or_default().as_secs_f32(), config.shutdown.duration().unwrap_or_default().as_secs_f32(),
); );
for (id, task) in state.connections.write().await.drain() { let _ = shutdown_tx.send(());
info!("Terminating task {id}");
task.abort();
}
break; break;
} }
@ -203,26 +207,28 @@ where
// Ensure that the shutdown timer is cancelled now that we have a connection // Ensure that the shutdown timer is cancelled now that we have a connection
timer.read().await.stop(); timer.read().await.stop();
let connection = ConnectionTask::build() connection_tasks.push(
.handler(Arc::downgrade(&handler)) ConnectionTask::build()
.state(Arc::downgrade(&state)) .handler(Arc::downgrade(&handler))
.keychain(state.keychain.clone()) .state(Arc::downgrade(&state))
.transport(transport) .keychain(state.keychain.clone())
.shutdown_timer(Arc::downgrade(&timer)) .transport(transport)
.sleep_duration(config.connection_sleep) .shutdown(shutdown_rx.resubscribe())
.verifier(Arc::downgrade(&verifier)) .shutdown_timer(Arc::downgrade(&timer))
.spawn(); .sleep_duration(config.connection_sleep)
.heartbeat_duration(config.connection_heartbeat)
state .verifier(Arc::downgrade(&verifier))
.connections .spawn(),
.write() );
.await
.insert(connection.id(), connection);
} }
// Once we stop listening, we still want to wait until all connections have terminated // Once we stop listening, we still want to wait until all connections have terminated
info!("Server waiting for active connections to terminate"); info!("Server waiting for active connections to terminate");
while state.has_active_connections().await { loop {
connection_tasks.retain(|task| !task.is_finished());
if connection_tasks.is_empty() {
break;
}
tokio::time::sleep(Duration::from_millis(50)).await; tokio::time::sleep(Duration::from_millis(50)).await;
} }
info!("Server task terminated"); info!("Server task terminated");

@ -3,6 +3,7 @@ use serde::{Deserialize, Serialize};
use std::{num::ParseFloatError, str::FromStr, time::Duration}; use std::{num::ParseFloatError, str::FromStr, time::Duration};
const DEFAULT_CONNECTION_SLEEP: Duration = Duration::from_millis(1); const DEFAULT_CONNECTION_SLEEP: Duration = Duration::from_millis(1);
const DEFAULT_HEARTBEAT_DURATION: Duration = Duration::from_secs(5);
/// Represents a general-purpose set of properties tied with a server instance /// Represents a general-purpose set of properties tied with a server instance
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
@ -10,6 +11,9 @@ pub struct ServerConfig {
/// Time to wait inbetween connection read/write when nothing was read or written on last pass /// Time to wait inbetween connection read/write when nothing was read or written on last pass
pub connection_sleep: Duration, pub connection_sleep: Duration,
/// Minimum time to wait inbetween sending heartbeat messages
pub connection_heartbeat: Duration,
/// Rules for how a server will shutdown automatically /// Rules for how a server will shutdown automatically
pub shutdown: Shutdown, pub shutdown: Shutdown,
} }
@ -18,6 +22,7 @@ impl Default for ServerConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
connection_sleep: DEFAULT_CONNECTION_SLEEP, connection_sleep: DEFAULT_CONNECTION_SLEEP,
connection_heartbeat: DEFAULT_HEARTBEAT_DURATION,
shutdown: Default::default(), shutdown: Default::default(),
} }
} }

@ -1,7 +1,10 @@
use super::{ConnectionCtx, ServerCtx, ServerHandler, ServerReply, ServerState, ShutdownTimer}; use super::{
ConnectionCtx, ConnectionState, ServerCtx, ServerHandler, ServerReply, ServerState,
ShutdownTimer,
};
use crate::common::{ use crate::common::{
authentication::{Keychain, Verifier}, authentication::{Keychain, Verifier},
Backup, Connection, ConnectionId, Interest, Response, Transport, UntypedRequest, Backup, Connection, Frame, Interest, Response, Transport, UntypedRequest,
}; };
use log::*; use log::*;
use serde::{de::DeserializeOwned, Serialize}; use serde::{de::DeserializeOwned, Serialize};
@ -11,56 +14,33 @@ use std::{
pin::Pin, pin::Pin,
sync::{Arc, Weak}, sync::{Arc, Weak},
task::{Context, Poll}, task::{Context, Poll},
time::Duration, time::{Duration, Instant},
}; };
use tokio::{ use tokio::{
sync::{mpsc, oneshot, RwLock}, sync::{broadcast, mpsc, oneshot, RwLock},
task::JoinHandle, task::JoinHandle,
}; };
pub type ServerKeychain = Keychain<oneshot::Receiver<Backup>>; pub type ServerKeychain = Keychain<oneshot::Receiver<Backup>>;
/// Time to wait inbetween connection read/write when nothing was read or written on last pass /// Time to wait inbetween connection read/write when nothing was read or written on last pass.
const SLEEP_DURATION: Duration = Duration::from_millis(1); const SLEEP_DURATION: Duration = Duration::from_millis(1);
/// Represents an individual connection on the server /// Minimum time between heartbeats to communicate to the client connection.
pub struct ConnectionTask { const MINIMUM_HEARTBEAT_DURATION: Duration = Duration::from_secs(5);
/// Unique identifier tied to the connection
id: ConnectionId,
/// Task that is processing requests and responses /// Represents an individual connection on the server.
task: JoinHandle<io::Result<()>>, pub(super) struct ConnectionTask(JoinHandle<io::Result<()>>);
}
impl ConnectionTask { impl ConnectionTask {
/// Starts building a new connection /// Starts building a new connection
pub fn build() -> ConnectionTaskBuilder<(), ()> { pub fn build() -> ConnectionTaskBuilder<(), (), ()> {
let id: ConnectionId = rand::random(); ConnectionTaskBuilder::new()
ConnectionTaskBuilder {
id,
handler: Weak::new(),
state: Weak::new(),
keychain: Keychain::new(),
transport: (),
shutdown_timer: Weak::new(),
sleep_duration: SLEEP_DURATION,
verifier: Weak::new(),
}
}
/// Returns the id associated with the connection
pub fn id(&self) -> ConnectionId {
self.id
} }
/// Returns true if the task has finished /// Returns true if the task has finished
pub fn is_finished(&self) -> bool { pub fn is_finished(&self) -> bool {
self.task.is_finished() self.0.is_finished()
}
/// Aborts the connection
pub fn abort(&self) {
self.task.abort();
} }
} }
@ -68,7 +48,7 @@ impl Future for ConnectionTask {
type Output = io::Result<()>; type Output = io::Result<()>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match Future::poll(Pin::new(&mut self.task), cx) { match Future::poll(Pin::new(&mut self.0), cx) {
Poll::Pending => Poll::Pending, Poll::Pending => Poll::Pending,
Poll::Ready(x) => match x { Poll::Ready(x) => match x {
Ok(x) => Poll::Ready(x), Ok(x) => Poll::Ready(x),
@ -78,114 +58,171 @@ impl Future for ConnectionTask {
} }
} }
pub struct ConnectionTaskBuilder<H, T> { /// Represents a builder for a new connection task.
id: ConnectionId, pub(super) struct ConnectionTaskBuilder<H, S, T> {
handler: Weak<H>, handler: Weak<H>,
state: Weak<ServerState>, state: Weak<ServerState<S>>,
keychain: Keychain<oneshot::Receiver<Backup>>, keychain: Keychain<oneshot::Receiver<Backup>>,
transport: T, transport: T,
shutdown: broadcast::Receiver<()>,
shutdown_timer: Weak<RwLock<ShutdownTimer>>, shutdown_timer: Weak<RwLock<ShutdownTimer>>,
sleep_duration: Duration, sleep_duration: Duration,
heartbeat_duration: Duration,
verifier: Weak<Verifier>, verifier: Weak<Verifier>,
} }
impl<H, T> ConnectionTaskBuilder<H, T> { impl ConnectionTaskBuilder<(), (), ()> {
pub fn handler<U>(self, handler: Weak<U>) -> ConnectionTaskBuilder<U, T> { /// Starts building a new connection.
pub fn new() -> Self {
Self {
handler: Weak::new(),
state: Weak::new(),
keychain: Keychain::new(),
transport: (),
shutdown: broadcast::channel(1).1,
shutdown_timer: Weak::new(),
sleep_duration: SLEEP_DURATION,
heartbeat_duration: MINIMUM_HEARTBEAT_DURATION,
verifier: Weak::new(),
}
}
}
impl<H, S, T> ConnectionTaskBuilder<H, S, T> {
pub fn handler<U>(self, handler: Weak<U>) -> ConnectionTaskBuilder<U, S, T> {
ConnectionTaskBuilder { ConnectionTaskBuilder {
id: self.id,
handler, handler,
state: self.state, state: self.state,
keychain: self.keychain, keychain: self.keychain,
transport: self.transport, transport: self.transport,
shutdown: self.shutdown,
shutdown_timer: self.shutdown_timer, shutdown_timer: self.shutdown_timer,
sleep_duration: self.sleep_duration, sleep_duration: self.sleep_duration,
heartbeat_duration: self.heartbeat_duration,
verifier: self.verifier, verifier: self.verifier,
} }
} }
pub fn state(self, state: Weak<ServerState>) -> ConnectionTaskBuilder<H, T> { pub fn state<U>(self, state: Weak<ServerState<U>>) -> ConnectionTaskBuilder<H, U, T> {
ConnectionTaskBuilder { ConnectionTaskBuilder {
id: self.id,
handler: self.handler, handler: self.handler,
state, state,
keychain: self.keychain, keychain: self.keychain,
transport: self.transport, transport: self.transport,
shutdown: self.shutdown,
shutdown_timer: self.shutdown_timer, shutdown_timer: self.shutdown_timer,
sleep_duration: self.sleep_duration, sleep_duration: self.sleep_duration,
heartbeat_duration: self.heartbeat_duration,
verifier: self.verifier, verifier: self.verifier,
} }
} }
pub fn keychain(self, keychain: ServerKeychain) -> ConnectionTaskBuilder<H, T> { pub fn keychain(self, keychain: ServerKeychain) -> ConnectionTaskBuilder<H, S, T> {
ConnectionTaskBuilder { ConnectionTaskBuilder {
id: self.id,
handler: self.handler, handler: self.handler,
state: self.state, state: self.state,
keychain, keychain,
transport: self.transport, transport: self.transport,
shutdown: self.shutdown,
shutdown_timer: self.shutdown_timer, shutdown_timer: self.shutdown_timer,
sleep_duration: self.sleep_duration, sleep_duration: self.sleep_duration,
heartbeat_duration: self.heartbeat_duration,
verifier: self.verifier, verifier: self.verifier,
} }
} }
pub fn transport<U>(self, transport: U) -> ConnectionTaskBuilder<H, U> { pub fn transport<U>(self, transport: U) -> ConnectionTaskBuilder<H, S, U> {
ConnectionTaskBuilder { ConnectionTaskBuilder {
id: self.id,
handler: self.handler, handler: self.handler,
keychain: self.keychain, keychain: self.keychain,
state: self.state, state: self.state,
transport, transport,
shutdown: self.shutdown,
shutdown_timer: self.shutdown_timer,
sleep_duration: self.sleep_duration,
heartbeat_duration: self.heartbeat_duration,
verifier: self.verifier,
}
}
pub fn shutdown(self, shutdown: broadcast::Receiver<()>) -> ConnectionTaskBuilder<H, S, T> {
ConnectionTaskBuilder {
handler: self.handler,
state: self.state,
keychain: self.keychain,
transport: self.transport,
shutdown,
shutdown_timer: self.shutdown_timer, shutdown_timer: self.shutdown_timer,
sleep_duration: self.sleep_duration, sleep_duration: self.sleep_duration,
heartbeat_duration: self.heartbeat_duration,
verifier: self.verifier, verifier: self.verifier,
} }
} }
pub(crate) fn shutdown_timer( pub fn shutdown_timer(
self, self,
shutdown_timer: Weak<RwLock<ShutdownTimer>>, shutdown_timer: Weak<RwLock<ShutdownTimer>>,
) -> ConnectionTaskBuilder<H, T> { ) -> ConnectionTaskBuilder<H, S, T> {
ConnectionTaskBuilder { ConnectionTaskBuilder {
id: self.id,
handler: self.handler, handler: self.handler,
state: self.state, state: self.state,
keychain: self.keychain, keychain: self.keychain,
transport: self.transport, transport: self.transport,
shutdown: self.shutdown,
shutdown_timer, shutdown_timer,
sleep_duration: self.sleep_duration, sleep_duration: self.sleep_duration,
heartbeat_duration: self.heartbeat_duration,
verifier: self.verifier, verifier: self.verifier,
} }
} }
pub fn sleep_duration(self, sleep_duration: Duration) -> ConnectionTaskBuilder<H, T> { pub fn sleep_duration(self, sleep_duration: Duration) -> ConnectionTaskBuilder<H, S, T> {
ConnectionTaskBuilder { ConnectionTaskBuilder {
id: self.id,
handler: self.handler, handler: self.handler,
state: self.state, state: self.state,
keychain: self.keychain, keychain: self.keychain,
transport: self.transport, transport: self.transport,
shutdown: self.shutdown,
shutdown_timer: self.shutdown_timer, shutdown_timer: self.shutdown_timer,
sleep_duration, sleep_duration,
heartbeat_duration: self.heartbeat_duration,
verifier: self.verifier,
}
}
pub fn heartbeat_duration(
self,
heartbeat_duration: Duration,
) -> ConnectionTaskBuilder<H, S, T> {
ConnectionTaskBuilder {
handler: self.handler,
state: self.state,
keychain: self.keychain,
transport: self.transport,
shutdown: self.shutdown,
shutdown_timer: self.shutdown_timer,
sleep_duration: self.sleep_duration,
heartbeat_duration,
verifier: self.verifier, verifier: self.verifier,
} }
} }
pub fn verifier(self, verifier: Weak<Verifier>) -> ConnectionTaskBuilder<H, T> { pub fn verifier(self, verifier: Weak<Verifier>) -> ConnectionTaskBuilder<H, S, T> {
ConnectionTaskBuilder { ConnectionTaskBuilder {
id: self.id,
handler: self.handler, handler: self.handler,
state: self.state, state: self.state,
keychain: self.keychain, keychain: self.keychain,
transport: self.transport, transport: self.transport,
shutdown: self.shutdown,
shutdown_timer: self.shutdown_timer, shutdown_timer: self.shutdown_timer,
sleep_duration: self.sleep_duration, sleep_duration: self.sleep_duration,
heartbeat_duration: self.heartbeat_duration,
verifier, verifier,
} }
} }
} }
impl<H, T> ConnectionTaskBuilder<H, T> impl<H, T> ConnectionTaskBuilder<H, Response<H::Response>, T>
where where
H: ServerHandler + Sync + 'static, H: ServerHandler + Sync + 'static,
H::Request: DeserializeOwned + Send + Sync + 'static, H::Request: DeserializeOwned + Send + Sync + 'static,
@ -194,52 +231,86 @@ where
T: Transport + 'static, T: Transport + 'static,
{ {
pub fn spawn(self) -> ConnectionTask { pub fn spawn(self) -> ConnectionTask {
let id = self.id; ConnectionTask(tokio::spawn(self.run()))
ConnectionTask {
id,
task: tokio::spawn(self.run()),
}
} }
async fn run(self) -> io::Result<()> { async fn run(self) -> io::Result<()> {
let ConnectionTaskBuilder { let ConnectionTaskBuilder {
id,
handler, handler,
state, state,
keychain, keychain,
transport, transport,
mut shutdown,
shutdown_timer, shutdown_timer,
sleep_duration, sleep_duration,
heartbeat_duration,
verifier, verifier,
} = self; } = self;
// NOTE: This exists purely to make the compiler happy for macro_rules declaration order.
let (mut local_shutdown, channel_tx, connection_state) = ConnectionState::channel();
// Will check if no more connections and restart timer if that's the case // Will check if no more connections and restart timer if that's the case
macro_rules! terminate_connection { macro_rules! terminate_connection {
// Prints an error message before terminating the connection by panicking // Prints an error message and does not store state
(@error $($msg:tt)+) => { (@fatal $($msg:tt)+) => {
error!($($msg)+); error!($($msg)+);
terminate_connection!(); terminate_connection!();
return Err(io::Error::new(io::ErrorKind::Other, format!($($msg)+))); return Err(io::Error::new(io::ErrorKind::Other, format!($($msg)+)));
}; };
// Prints a debug message before terminating the connection by cleanly returning // Prints an error message and stores state before terminating
(@debug $($msg:tt)+) => { (@error($tx:ident, $rx:ident) $($msg:tt)+) => {
error!($($msg)+);
terminate_connection!($tx, $rx);
return Err(io::Error::new(io::ErrorKind::Other, format!($($msg)+)));
};
// Prints a debug message and stores state before terminating
(@debug($tx:ident, $rx:ident) $($msg:tt)+) => {
debug!($($msg)+); debug!($($msg)+);
terminate_connection!($tx, $rx);
return Ok(());
};
// Prints a shutdown message with no connection id and exit without sending state
(@shutdown) => {
debug!("Shutdown triggered before a connection could be fully established");
terminate_connection!();
return Ok(());
};
// Prints a shutdown message with no connection id and stores state before terminating
(@shutdown) => {
debug!("Shutdown triggered before a connection could be fully established");
terminate_connection!(); terminate_connection!();
return Ok(()); return Ok(());
}; };
// Prints a shutdown message and stores state before terminating
(@shutdown($id:ident, $tx:ident, $rx:ident)) => {{
debug!("[Conn {}] Shutdown triggered", $id);
terminate_connection!($tx, $rx);
return Ok(());
}};
// Performs the connection termination by removing it from server state and
// restarting the shutdown timer if it was the last connection
($tx:ident, $rx:ident) => {
// Send the channels back
let _ = channel_tx.send(($tx, $rx));
terminate_connection!();
};
// Performs the connection termination by removing it from server state and // Performs the connection termination by removing it from server state and
// restarting the shutdown timer if it was the last connection // restarting the shutdown timer if it was the last connection
() => { () => {
// Remove the connection from our state if it has closed // Restart our shutdown timer if this is the last connection
if let Some(state) = Weak::upgrade(&state) { if let Some(state) = Weak::upgrade(&state) {
state.connections.write().await.remove(&self.id);
// If we have no more connections, start the timer
if let Some(timer) = Weak::upgrade(&shutdown_timer) { if let Some(timer) = Weak::upgrade(&shutdown_timer) {
if state.connections.read().await.is_empty() { if state.connections.read().await.values().filter(|conn| !conn.is_finished()).count() <= 1 {
debug!("Last connection terminating, so restarting shutdown timer");
timer.write().await.restart(); timer.write().await.restart();
} }
} }
@ -247,58 +318,160 @@ where
}; };
} }
/// Awaits a future to complete, or detects if a signal was received by either the global
/// or local shutdown channel. Shutdown only occurs if a signal was received, and any
/// errors received by either shutdown channel are ignored.
macro_rules! await_or_shutdown {
($(@save($id:ident, $tx:ident, $rx:ident))? $future:expr) => {{
let mut f = $future;
loop {
let use_shutdown = match shutdown.try_recv() {
Ok(_) => {
terminate_connection!(@shutdown $(($id, $tx, $rx))?);
}
Err(broadcast::error::TryRecvError::Empty) => true,
Err(broadcast::error::TryRecvError::Lagged(_)) => true,
Err(broadcast::error::TryRecvError::Closed) => false,
};
let use_local_shutdown = match local_shutdown.try_recv() {
Ok(_) => {
terminate_connection!(@shutdown $(($id, $tx, $rx))?);
}
Err(oneshot::error::TryRecvError::Empty) => true,
Err(oneshot::error::TryRecvError::Closed) => false,
};
if use_shutdown && use_local_shutdown {
tokio::select! {
x = shutdown.recv() => {
if x.is_err() {
continue;
}
terminate_connection!(@shutdown $(($id, $tx, $rx))?);
}
x = &mut local_shutdown => {
if x.is_err() {
continue;
}
terminate_connection!(@shutdown $(($id, $tx, $rx))?);
}
x = &mut f => { break x; }
}
} else if use_shutdown {
tokio::select! {
x = shutdown.recv() => {
if x.is_err() {
continue;
}
terminate_connection!(@shutdown $(($id, $tx, $rx))?);
}
x = &mut f => { break x; }
}
} else if use_local_shutdown {
tokio::select! {
x = &mut local_shutdown => {
if x.is_err() {
continue;
}
terminate_connection!(@shutdown $(($id, $tx, $rx))?);
}
x = &mut f => { break x; }
}
} else {
break f.await;
}
}
}};
}
// Attempt to upgrade our handler for use with the connection going forward
let handler = match Weak::upgrade(&handler) {
Some(handler) => handler,
None => {
terminate_connection!(@fatal "Failed to setup connection because handler dropped");
}
};
// Attempt to upgrade our state for use with the connection going forward
let state = match Weak::upgrade(&state) {
Some(state) => state,
None => {
terminate_connection!(@fatal "Failed to setup connection because state dropped");
}
};
// Properly establish the connection's transport // Properly establish the connection's transport
debug!("[Conn {id}] Establishing full connection"); debug!("Establishing full connection using {transport:?}");
let mut connection = match Weak::upgrade(&verifier) { let mut connection = match Weak::upgrade(&verifier) {
Some(verifier) => { Some(verifier) => {
match Connection::server(transport, verifier.as_ref(), keychain).await { match await_or_shutdown!(Box::pin(Connection::server(
transport,
verifier.as_ref(),
keychain
))) {
Ok(connection) => connection, Ok(connection) => connection,
Err(x) => { Err(x) => {
terminate_connection!(@error "[Conn {id}] Failed to setup connection: {x}"); terminate_connection!(@fatal "Failed to setup connection: {x}");
} }
} }
} }
None => { None => {
terminate_connection!(@error "[Conn {id}] Verifier has been dropped"); terminate_connection!(@fatal "Verifier has been dropped");
}
};
// Attempt to upgrade our handler for use with the connection going forward
debug!("[Conn {id}] Preparing connection handler");
let handler = match Weak::upgrade(&handler) {
Some(handler) => handler,
None => {
terminate_connection!(@error "[Conn {id}] Handler has been dropped");
} }
}; };
// Construct a queue of outgoing responses // Update our id to be the connection id
let (tx, mut rx) = mpsc::channel::<Response<H::Response>>(1); let id = connection.id();
// Create local data for the connection and then process it // Create local data for the connection and then process it
debug!("[Conn {id}] Officially accepting connection"); debug!("[Conn {id}] Officially accepting connection");
let mut local_data = H::LocalData::default(); let mut local_data = H::LocalData::default();
if let Err(x) = handler if let Err(x) = await_or_shutdown!(handler.on_accept(ConnectionCtx {
.on_accept(ConnectionCtx { connection_id: id,
connection_id: id, local_data: &mut local_data
local_data: &mut local_data, })) {
}) terminate_connection!(@fatal "[Conn {id}] Accepting connection failed: {x}");
.await
{
terminate_connection!(@error "[Conn {id}] Accepting connection failed: {x}");
} }
let local_data = Arc::new(local_data); let local_data = Arc::new(local_data);
let mut last_heartbeat = Instant::now();
// Restore our connection's channels if we have them, otherwise make new ones
let (tx, mut rx) = match state.connections.write().await.remove(&id) {
Some(conn) => match conn.shutdown_and_wait().await {
Some(x) => {
debug!("[Conn {id}] Marked as existing connection");
x
}
None => {
warn!("[Conn {id}] Existing connection with id, but channels not saved");
mpsc::channel::<Response<H::Response>>(1)
}
},
None => {
debug!("[Conn {id}] Marked as new connection");
mpsc::channel::<Response<H::Response>>(1)
}
};
// Store our connection details
state.connections.write().await.insert(id, connection_state);
debug!("[Conn {id}] Beginning read/write loop"); debug!("[Conn {id}] Beginning read/write loop");
loop { loop {
let ready = match connection let ready = match await_or_shutdown!(
.ready(Interest::READABLE | Interest::WRITABLE) @save(id, tx, rx)
.await Box::pin(connection.ready(Interest::READABLE | Interest::WRITABLE))
{ ) {
Ok(ready) => ready, Ok(ready) => ready,
Err(x) => { Err(x) => {
terminate_connection!(@error "[Conn {id}] Failed to examine ready state: {x}"); terminate_connection!(@error(tx, rx) "[Conn {id}] Failed to examine ready state: {x}");
} }
}; };
@ -311,15 +484,14 @@ where
Ok(Some(frame)) => match UntypedRequest::from_slice(frame.as_item()) { Ok(Some(frame)) => match UntypedRequest::from_slice(frame.as_item()) {
Ok(request) => match request.to_typed_request() { Ok(request) => match request.to_typed_request() {
Ok(request) => { Ok(request) => {
let reply = ServerReply { let origin_id = request.id.clone();
origin_id: request.id.clone(),
tx: tx.clone(),
};
let ctx = ServerCtx { let ctx = ServerCtx {
connection_id: id, connection_id: id,
request, request,
reply: reply.clone(), reply: ServerReply {
origin_id,
tx: tx.clone(),
},
local_data: Arc::clone(&local_data), local_data: Arc::clone(&local_data),
}; };
@ -344,11 +516,11 @@ where
} }
}, },
Ok(None) => { Ok(None) => {
terminate_connection!(@debug "[Conn {id}] Connection closed"); terminate_connection!(@debug(tx, rx) "[Conn {id}] Connection closed");
} }
Err(x) if x.kind() == io::ErrorKind::WouldBlock => read_blocked = true, Err(x) if x.kind() == io::ErrorKind::WouldBlock => read_blocked = true,
Err(x) => { Err(x) => {
terminate_connection!(@error "[Conn {id}] {x}"); terminate_connection!(@error(tx, rx) "[Conn {id}] {x}");
} }
} }
} }
@ -356,10 +528,20 @@ where
// If our socket is ready to be written to, we try to get the next item from // If our socket is ready to be written to, we try to get the next item from
// the queue and process it // the queue and process it
if ready.is_writable() { if ready.is_writable() {
// Send a heartbeat if we have exceeded our last time
if last_heartbeat.elapsed() >= heartbeat_duration {
trace!("[Conn {id}] Sending heartbeat via empty frame");
match connection.try_write_frame(Frame::empty()) {
Ok(()) => (),
Err(x) if x.kind() == io::ErrorKind::WouldBlock => write_blocked = true,
Err(x) => error!("[Conn {id}] Send failed: {x}"),
}
last_heartbeat = Instant::now();
}
// If we get more data to write, attempt to write it, which will result in writing // If we get more data to write, attempt to write it, which will result in writing
// any queued bytes as well. Othewise, we attempt to flush any pending outgoing // any queued bytes as well. Othewise, we attempt to flush any pending outgoing
// bytes that weren't sent earlier. // bytes that weren't sent earlier.
if let Ok(response) = rx.try_recv() { else if let Ok(response) = rx.try_recv() {
// Log our message as a string, which can be expensive // Log our message as a string, which can be expensive
if log_enabled!(Level::Trace) { if log_enabled!(Level::Trace) {
trace!( trace!(
@ -541,7 +723,7 @@ mod tests {
let err = task.await.unwrap_err(); let err = task.await.unwrap_err();
assert!( assert!(
err.to_string().contains("Handler has been dropped"), err.to_string().contains("handler dropped"),
"Unexpected error: {err}" "Unexpected error: {err}"
); );
} }
@ -610,6 +792,7 @@ mod tests {
let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never))); let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
let verifier = Arc::new(Verifier::none()); let verifier = Arc::new(Verifier::none());
#[derive(Debug)]
struct FakeTransport { struct FakeTransport {
inner: InmemoryTransport, inner: InmemoryTransport,
fail_ready: Arc<AtomicBool>, fail_ready: Arc<AtomicBool>,
@ -678,7 +861,7 @@ mod tests {
let err = task.await.unwrap_err(); let err = task.await.unwrap_err();
assert!( assert!(
err.to_string().contains("Failed to examine ready state"), err.to_string().contains("targeted ready failure"),
"Unexpected error: {err}" "Unexpected error: {err}"
); );
} }
@ -722,7 +905,7 @@ mod tests {
let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never))); let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
let verifier = Arc::new(Verifier::none()); let verifier = Arc::new(Verifier::none());
ConnectionTask::build() let _conn = ConnectionTask::build()
.handler(Arc::downgrade(&handler)) .handler(Arc::downgrade(&handler))
.state(Arc::downgrade(&state)) .state(Arc::downgrade(&state))
.keychain(keychain) .keychain(keychain)
@ -748,4 +931,205 @@ mod tests {
let response = task.await.unwrap(); let response = task.await.unwrap();
assert_eq!(response.payload, "hello"); assert_eq!(response.payload, "hello");
} }
#[test(tokio::test)]
async fn should_send_heartbeat_via_empty_frame_every_minimum_duration() {
let handler = Arc::new(TestServerHandler);
let state = Arc::new(ServerState::default());
let keychain = ServerKeychain::new();
let (t1, t2) = InmemoryTransport::pair(100);
let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
let verifier = Arc::new(Verifier::none());
let _conn = ConnectionTask::build()
.handler(Arc::downgrade(&handler))
.state(Arc::downgrade(&state))
.keychain(keychain)
.transport(t1)
.shutdown_timer(Arc::downgrade(&shutdown_timer))
.heartbeat_duration(Duration::from_millis(200))
.verifier(Arc::downgrade(&verifier))
.spawn();
// Spawn a task to handle establishing connection from client-side
let task = tokio::spawn(async move {
let mut client = Connection::client(t2, DummyAuthHandler)
.await
.expect("Fail to establish client-side connection");
// Verify we don't get a frame immediately
assert_eq!(
client.try_read_frame().unwrap_err().kind(),
io::ErrorKind::WouldBlock,
"got a frame early"
);
// Sleep more than our minimum heartbeat duration to ensure we get one
tokio::time::sleep(Duration::from_millis(250)).await;
assert_eq!(
client.read_frame().await.unwrap().unwrap(),
Frame::empty(),
"non-empty frame"
);
// Verify we don't get a frame immediately
assert_eq!(
client.try_read_frame().unwrap_err().kind(),
io::ErrorKind::WouldBlock,
"got a frame early"
);
// Sleep more than our minimum heartbeat duration to ensure we get one
tokio::time::sleep(Duration::from_millis(250)).await;
assert_eq!(
client.read_frame().await.unwrap().unwrap(),
Frame::empty(),
"non-empty frame"
);
});
task.await.unwrap();
}
#[test(tokio::test)]
async fn should_be_able_to_shutdown_while_establishing_connection() {
let handler = Arc::new(TestServerHandler);
let state = Arc::new(ServerState::default());
let keychain = ServerKeychain::new();
let (t1, _t2) = InmemoryTransport::pair(100);
let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
let verifier = Arc::new(Verifier::none());
let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
let conn = ConnectionTask::build()
.handler(Arc::downgrade(&handler))
.state(Arc::downgrade(&state))
.keychain(keychain)
.transport(t1)
.shutdown(shutdown_rx)
.shutdown_timer(Arc::downgrade(&shutdown_timer))
.heartbeat_duration(Duration::from_millis(200))
.verifier(Arc::downgrade(&verifier))
.spawn();
// Shutdown server connection task while it is establishing a full connection with the
// client, verifying that we do not get an error in return
shutdown_tx
.send(())
.expect("Failed to send shutdown signal");
conn.await.unwrap();
}
#[test(tokio::test)]
async fn should_be_able_to_shutdown_while_accepting_connection() {
struct HangingAcceptServerHandler;
#[async_trait]
impl ServerHandler for HangingAcceptServerHandler {
type Request = ();
type Response = ();
type LocalData = ();
async fn on_accept(&self, _: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> {
// Wait "forever" so we can ensure that we fail at this step
tokio::time::sleep(Duration::MAX).await;
Err(io::Error::new(io::ErrorKind::Other, "bad accept"))
}
async fn on_request(
&self,
_: ServerCtx<Self::Request, Self::Response, Self::LocalData>,
) {
unreachable!();
}
}
let handler = Arc::new(HangingAcceptServerHandler);
let state = Arc::new(ServerState::default());
let keychain = ServerKeychain::new();
let (t1, t2) = InmemoryTransport::pair(100);
let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
let verifier = Arc::new(Verifier::none());
let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
let conn = ConnectionTask::build()
.handler(Arc::downgrade(&handler))
.state(Arc::downgrade(&state))
.keychain(keychain)
.transport(t1)
.shutdown(shutdown_rx)
.shutdown_timer(Arc::downgrade(&shutdown_timer))
.heartbeat_duration(Duration::from_millis(200))
.verifier(Arc::downgrade(&verifier))
.spawn();
// Spawn a task to handle the client-side establishment of a full connection
let _client_task = tokio::spawn(Connection::client(t2, DummyAuthHandler));
// Shutdown server connection task while it is accepting the connection, verifying that we
// do not get an error in return
shutdown_tx
.send(())
.expect("Failed to send shutdown signal");
conn.await.unwrap();
}
#[test(tokio::test)]
async fn should_be_able_to_shutdown_while_waiting_for_connection_to_be_ready() {
struct AcceptServerHandler {
tx: mpsc::Sender<()>,
}
#[async_trait]
impl ServerHandler for AcceptServerHandler {
type Request = ();
type Response = ();
type LocalData = ();
async fn on_accept(&self, _: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> {
self.tx.send(()).await.unwrap();
Ok(())
}
async fn on_request(
&self,
_: ServerCtx<Self::Request, Self::Response, Self::LocalData>,
) {
unreachable!();
}
}
let (tx, mut rx) = mpsc::channel(100);
let handler = Arc::new(AcceptServerHandler { tx });
let state = Arc::new(ServerState::default());
let keychain = ServerKeychain::new();
let (t1, t2) = InmemoryTransport::pair(100);
let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
let verifier = Arc::new(Verifier::none());
let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
let conn = ConnectionTask::build()
.handler(Arc::downgrade(&handler))
.state(Arc::downgrade(&state))
.keychain(keychain)
.transport(t1)
.shutdown(shutdown_rx)
.shutdown_timer(Arc::downgrade(&shutdown_timer))
.heartbeat_duration(Duration::from_millis(200))
.verifier(Arc::downgrade(&verifier))
.spawn();
// Spawn a task to handle the client-side establishment of a full connection
let _client_task = tokio::spawn(Connection::client(t2, DummyAuthHandler));
// Wait to ensure we complete the accept call first
let _ = rx.recv().await;
// Shutdown server connection task while it is accepting the connection, verifying that we
// do not get an error in return
shutdown_tx
.send(())
.expect("Failed to send shutdown signal");
conn.await.unwrap();
}
} }

@ -1,23 +1,21 @@
use super::ServerState;
use crate::common::AsAny; use crate::common::AsAny;
use log::*;
use std::{ use std::{
future::Future, future::Future,
io, io,
pin::Pin, pin::Pin,
sync::Arc,
task::{Context, Poll}, task::{Context, Poll},
time::Duration, time::Duration,
}; };
use tokio::sync::broadcast;
use tokio::task::{JoinError, JoinHandle}; use tokio::task::{JoinError, JoinHandle};
/// Interface to engage with a server instance /// Interface to engage with a server instance.
pub trait ServerRef: AsAny + Send { pub trait ServerRef: AsAny + Send {
/// Returns true if the server is no longer running /// Returns true if the server is no longer running.
fn is_finished(&self) -> bool; fn is_finished(&self) -> bool;
/// Kills the internal task processing new inbound requests /// Sends a shutdown signal to the server.
fn abort(&self); fn shutdown(&self);
fn wait(self) -> Pin<Box<dyn Future<Output = io::Result<()>>>> fn wait(self) -> Pin<Box<dyn Future<Output = io::Result<()>>>>
where where
@ -64,7 +62,7 @@ impl dyn ServerRef {
/// Represents a generic reference to a server /// Represents a generic reference to a server
pub struct GenericServerRef { pub struct GenericServerRef {
pub(crate) state: Arc<ServerState>, pub(crate) shutdown: broadcast::Sender<()>,
pub(crate) task: JoinHandle<()>, pub(crate) task: JoinHandle<()>,
} }
@ -74,16 +72,8 @@ impl ServerRef for GenericServerRef {
self.task.is_finished() self.task.is_finished()
} }
fn abort(&self) { fn shutdown(&self) {
self.task.abort(); let _ = self.shutdown.send(());
let state = Arc::clone(&self.state);
tokio::spawn(async move {
for (id, connection) in state.connections.read().await.iter() {
debug!("Aborting connection {}", id);
connection.abort();
}
});
} }
fn wait(self) -> Pin<Box<dyn Future<Output = io::Result<()>>>> fn wait(self) -> Pin<Box<dyn Future<Output = io::Result<()>>>>

@ -29,7 +29,7 @@ impl ServerRef for TcpServerRef {
self.inner.is_finished() self.inner.is_finished()
} }
fn abort(&self) { fn shutdown(&self) {
self.inner.abort(); self.inner.shutdown();
} }
} }

@ -28,7 +28,7 @@ impl ServerRef for UnixSocketServerRef {
self.inner.is_finished() self.inner.is_finished()
} }
fn abort(&self) { fn shutdown(&self) {
self.inner.abort(); self.inner.shutdown();
} }
} }

@ -28,7 +28,7 @@ impl ServerRef for WindowsPipeServerRef {
self.inner.is_finished() self.inner.is_finished()
} }
fn abort(&self) { fn shutdown(&self) {
self.inner.abort(); self.inner.shutdown();
} }
} }

@ -1,37 +1,70 @@
use super::ConnectionTask;
use crate::common::{authentication::Keychain, Backup, ConnectionId}; use crate::common::{authentication::Keychain, Backup, ConnectionId};
use std::collections::HashMap; use std::collections::HashMap;
use tokio::sync::{oneshot, RwLock}; use tokio::sync::{mpsc, oneshot, RwLock};
use tokio::task::JoinHandle;
/// Contains all top-level state for the server /// Contains all top-level state for the server
pub struct ServerState { pub struct ServerState<T> {
/// Mapping of connection ids to their transports /// Mapping of connection ids to their tasks.
pub connections: RwLock<HashMap<ConnectionId, ConnectionTask>>, pub connections: RwLock<HashMap<ConnectionId, ConnectionState<T>>>,
/// Mapping of connection ids to (OTP, backup) /// Mapping of connection ids to (OTP, backup)
pub keychain: Keychain<oneshot::Receiver<Backup>>, pub keychain: Keychain<oneshot::Receiver<Backup>>,
} }
impl ServerState { impl<T> ServerState<T> {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
connections: RwLock::new(HashMap::new()), connections: RwLock::new(HashMap::new()),
keychain: Keychain::new(), keychain: Keychain::new(),
} }
} }
/// Returns true if there is at least one active connection
pub async fn has_active_connections(&self) -> bool {
self.connections
.read()
.await
.values()
.any(|task| !task.is_finished())
}
} }
impl Default for ServerState { impl<T> Default for ServerState<T> {
fn default() -> Self { fn default() -> Self {
Self::new() Self::new()
} }
} }
pub struct ConnectionState<T> {
shutdown_tx: oneshot::Sender<()>,
task: JoinHandle<Option<(mpsc::Sender<T>, mpsc::Receiver<T>)>>,
}
impl<T: Send + 'static> ConnectionState<T> {
/// Creates new state with appropriate channels, returning
/// (shutdown receiver, channel sender, state).
#[allow(clippy::type_complexity)]
pub fn channel() -> (
oneshot::Receiver<()>,
oneshot::Sender<(mpsc::Sender<T>, mpsc::Receiver<T>)>,
Self,
) {
let (shutdown_tx, shutdown_rx) = oneshot::channel();
let (channel_tx, channel_rx) = oneshot::channel();
(
shutdown_rx,
channel_tx,
Self {
shutdown_tx,
task: tokio::spawn(async move {
match channel_rx.await {
Ok(x) => Some(x),
Err(_) => None,
}
}),
},
)
}
pub fn is_finished(&self) -> bool {
self.task.is_finished()
}
pub async fn shutdown_and_wait(self) -> Option<(mpsc::Sender<T>, mpsc::Receiver<T>)> {
let _ = self.shutdown_tx.send(());
self.task.await.unwrap()
}
}

@ -1,6 +1,6 @@
use async_trait::async_trait; use async_trait::async_trait;
use distant_net::boxed_connect_handler; use distant_net::boxed_connect_handler;
use distant_net::client::{Client, ReconnectStrategy}; use distant_net::client::Client;
use distant_net::common::authentication::{DummyAuthHandler, Verifier}; use distant_net::common::authentication::{DummyAuthHandler, Verifier};
use distant_net::common::{Destination, InmemoryTransport, Map, OneshotListener}; use distant_net::common::{Destination, InmemoryTransport, Map, OneshotListener};
use distant_net::manager::{Config, ManagerClient, ManagerServer}; use distant_net::manager::{Config, ManagerClient, ManagerServer};
@ -43,7 +43,6 @@ async fn should_be_able_to_establish_a_single_connection_and_communicate_with_a_
let client = Client::build() let client = Client::build()
.auth_handler(DummyAuthHandler) .auth_handler(DummyAuthHandler)
.reconnect_strategy(ReconnectStrategy::Fail)
.connector(t1) .connector(t1)
.connect_untyped() .connect_untyped()
.await?; .await?;
@ -61,7 +60,6 @@ async fn should_be_able_to_establish_a_single_connection_and_communicate_with_a_
info!("Connecting to manager"); info!("Connecting to manager");
let mut client: ManagerClient = Client::build() let mut client: ManagerClient = Client::build()
.auth_handler(DummyAuthHandler) .auth_handler(DummyAuthHandler)
.reconnect_strategy(ReconnectStrategy::Fail)
.connector(t1) .connector(t1)
.connect() .connect()
.await .await

@ -1,5 +1,5 @@
use async_trait::async_trait; use async_trait::async_trait;
use distant_net::client::{Client, ReconnectStrategy}; use distant_net::client::Client;
use distant_net::common::authentication::{DummyAuthHandler, Verifier}; use distant_net::common::authentication::{DummyAuthHandler, Verifier};
use distant_net::common::{InmemoryTransport, OneshotListener}; use distant_net::common::{InmemoryTransport, OneshotListener};
use distant_net::server::{Server, ServerCtx, ServerHandler}; use distant_net::server::{Server, ServerCtx, ServerHandler};
@ -38,7 +38,6 @@ async fn should_be_able_to_send_and_receive_typed_payloads_between_client_and_se
let mut client: Client<(u8, String), String> = Client::build() let mut client: Client<(u8, String), String> = Client::build()
.auth_handler(DummyAuthHandler) .auth_handler(DummyAuthHandler)
.reconnect_strategy(ReconnectStrategy::Fail)
.connector(t1) .connector(t1)
.connect() .connect()
.await .await

@ -1,5 +1,5 @@
use async_trait::async_trait; use async_trait::async_trait;
use distant_net::client::{Client, ReconnectStrategy}; use distant_net::client::Client;
use distant_net::common::authentication::{DummyAuthHandler, Verifier}; use distant_net::common::authentication::{DummyAuthHandler, Verifier};
use distant_net::common::{InmemoryTransport, OneshotListener, Request}; use distant_net::common::{InmemoryTransport, OneshotListener, Request};
use distant_net::server::{Server, ServerCtx, ServerHandler}; use distant_net::server::{Server, ServerCtx, ServerHandler};
@ -38,7 +38,6 @@ async fn should_be_able_to_send_and_receive_untyped_payloads_between_client_and_
let mut client = Client::build() let mut client = Client::build()
.auth_handler(DummyAuthHandler) .auth_handler(DummyAuthHandler)
.reconnect_strategy(ReconnectStrategy::Fail)
.connector(t1) .connector(t1)
.connect_untyped() .connect_untyped()
.await .await

@ -2,7 +2,7 @@
name = "distant-ssh2" name = "distant-ssh2"
description = "Library to enable native ssh-2 protocol for use with distant sessions" description = "Library to enable native ssh-2 protocol for use with distant sessions"
categories = ["network-programming"] categories = ["network-programming"]
version = "0.20.0-alpha.2" version = "0.20.0-alpha.3"
authors = ["Chip Senkbeil <chip@senkbeil.org>"] authors = ["Chip Senkbeil <chip@senkbeil.org>"]
edition = "2021" edition = "2021"
homepage = "https://github.com/chipsenkbeil/distant" homepage = "https://github.com/chipsenkbeil/distant"
@ -20,7 +20,7 @@ async-compat = "0.2.1"
async-once-cell = "0.4.2" async-once-cell = "0.4.2"
async-trait = "0.1.58" async-trait = "0.1.58"
derive_more = { version = "0.99.17", default-features = false, features = ["display", "error"] } derive_more = { version = "0.99.17", default-features = false, features = ["display", "error"] }
distant-core = { version = "=0.20.0-alpha.2", path = "../distant-core" } distant-core = { version = "=0.20.0-alpha.3", path = "../distant-core" }
futures = "0.3.25" futures = "0.3.25"
hex = "0.4.3" hex = "0.4.3"
log = "0.4.17" log = "0.4.17"

@ -694,12 +694,11 @@ impl DistantApi for SshDistantApi {
cmd: String, cmd: String,
environment: Environment, environment: Environment,
current_dir: Option<PathBuf>, current_dir: Option<PathBuf>,
persist: bool,
pty: Option<PtySize>, pty: Option<PtySize>,
) -> io::Result<ProcessId> { ) -> io::Result<ProcessId> {
debug!( debug!(
"[Conn {}] Spawning {} {{environment: {:?}, current_dir: {:?}, persist: {}, pty: {:?}}}", "[Conn {}] Spawning {} {{environment: {:?}, current_dir: {:?}, pty: {:?}}}",
ctx.connection_id, cmd, environment, current_dir, persist, pty ctx.connection_id, cmd, environment, current_dir, pty
); );
let global_processes = Arc::downgrade(&self.processes); let global_processes = Arc::downgrade(&self.processes);
@ -744,12 +743,6 @@ impl DistantApi for SshDistantApi {
} }
}; };
// If the process will be killed when the connection ends, we want to add it
// to our local data
if !persist {
ctx.local_data.processes.write().await.insert(id);
}
self.processes.write().await.insert( self.processes.write().await.insert(
id, id,
Process { Process {

@ -7,7 +7,7 @@ use async_trait::async_trait;
use distant_core::{ use distant_core::{
data::Environment, data::Environment,
net::{ net::{
client::{Client, ReconnectStrategy}, client::{Client, ClientConfig},
common::authentication::{AuthHandlerMap, DummyAuthHandler, Verifier}, common::authentication::{AuthHandlerMap, DummyAuthHandler, Verifier},
common::{InmemoryTransport, OneshotListener}, common::{InmemoryTransport, OneshotListener},
server::{Server, ServerRef}, server::{Server, ServerRef},
@ -574,7 +574,7 @@ impl Ssh {
debug!("Attempting to connect to distant server @ {}", addr); debug!("Attempting to connect to distant server @ {}", addr);
match Client::tcp(addr) match Client::tcp(addr)
.auth_handler(AuthHandlerMap::new().with_static_key(key.clone())) .auth_handler(AuthHandlerMap::new().with_static_key(key.clone()))
.timeout(timeout) .connect_timeout(timeout)
.connect() .connect()
.await .await
{ {
@ -646,7 +646,7 @@ impl Ssh {
); );
// Close out ssh client by killing the internal server and client // Close out ssh client by killing the internal server and client
server.abort(); server.shutdown();
client.abort(); client.abort();
let _ = client.wait().await; let _ = client.wait().await;
@ -718,8 +718,8 @@ impl Ssh {
.start(OneshotListener::from_value(t2))?; .start(OneshotListener::from_value(t2))?;
let client = Client::build() let client = Client::build()
.auth_handler(DummyAuthHandler) .auth_handler(DummyAuthHandler)
.config(ClientConfig::default().with_maximum_silence_duration())
.connector(t1) .connector(t1)
.reconnect_strategy(ReconnectStrategy::Fail)
.connect() .connect()
.await?; .await?;
Ok((client, server)) Ok((client, server))

@ -1217,7 +1217,6 @@ async fn proc_spawn_should_not_fail_even_if_process_not_found(
/* cmd */ DOES_NOT_EXIST_BIN.to_str().unwrap().to_string(), /* cmd */ DOES_NOT_EXIST_BIN.to_str().unwrap().to_string(),
/* environment */ Environment::new(), /* environment */ Environment::new(),
/* current_dir */ None, /* current_dir */ None,
/* persist */ false,
/* pty */ None, /* pty */ None,
) )
.await .await
@ -1239,7 +1238,6 @@ async fn proc_spawn_should_return_id_of_spawned_process(#[future] client: Ctx<Di
), ),
/* environment */ Environment::new(), /* environment */ Environment::new(),
/* current_dir */ None, /* current_dir */ None,
/* persist */ false,
/* pty */ None, /* pty */ None,
) )
.await .await
@ -1267,7 +1265,6 @@ async fn proc_spawn_should_send_back_stdout_periodically_when_available(
), ),
/* environment */ Environment::new(), /* environment */ Environment::new(),
/* current_dir */ None, /* current_dir */ None,
/* persist */ false,
/* pty */ None, /* pty */ None,
) )
.await .await
@ -1303,7 +1300,6 @@ async fn proc_spawn_should_send_back_stderr_periodically_when_available(
), ),
/* environment */ Environment::new(), /* environment */ Environment::new(),
/* current_dir */ None, /* current_dir */ None,
/* persist */ false,
/* pty */ None, /* pty */ None,
) )
.await .await
@ -1333,7 +1329,6 @@ async fn proc_spawn_should_send_done_signal_when_completed(#[future] client: Ctx
format!("{} {} 0.1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()), format!("{} {} 0.1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()),
/* environment */ Environment::new(), /* environment */ Environment::new(),
/* current_dir */ None, /* current_dir */ None,
/* persist */ false,
/* pty */ None, /* pty */ None,
) )
.await .await
@ -1355,7 +1350,6 @@ async fn proc_spawn_should_clear_process_from_state_when_killed(
format!("{} {} 1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()), format!("{} {} 1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()),
/* environment */ Environment::new(), /* environment */ Environment::new(),
/* current_dir */ None, /* current_dir */ None,
/* persist */ false,
/* pty */ None, /* pty */ None,
) )
.await .await
@ -1380,7 +1374,6 @@ async fn proc_kill_should_fail_if_process_not_running(#[future] client: Ctx<Dist
format!("{} {} 1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()), format!("{} {} 1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()),
/* environment */ Environment::new(), /* environment */ Environment::new(),
/* current_dir */ None, /* current_dir */ None,
/* persist */ false,
/* pty */ None, /* pty */ None,
) )
.await .await
@ -1408,7 +1401,6 @@ async fn proc_stdin_should_fail_if_process_not_running(#[future] client: Ctx<Dis
format!("{} {} 1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()), format!("{} {} 1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()),
/* environment */ Environment::new(), /* environment */ Environment::new(),
/* current_dir */ None, /* current_dir */ None,
/* persist */ false,
/* pty */ None, /* pty */ None,
) )
.await .await
@ -1444,7 +1436,6 @@ async fn proc_stdin_should_send_stdin_to_process(#[future] client: Ctx<DistantCl
), ),
/* environment */ Environment::new(), /* environment */ Environment::new(),
/* current_dir */ None, /* current_dir */ None,
/* persist */ false,
/* pty */ None, /* pty */ None,
) )
.await .await

@ -1199,7 +1199,6 @@ async fn proc_spawn_should_fail_if_process_not_found(
/* cmd */ DOES_NOT_EXIST_BIN.to_str().unwrap().to_string(), /* cmd */ DOES_NOT_EXIST_BIN.to_str().unwrap().to_string(),
/* environment */ Environment::new(), /* environment */ Environment::new(),
/* current_dir */ None, /* current_dir */ None,
/* persist */ false,
/* pty */ None, /* pty */ None,
) )
.await .await
@ -1223,7 +1222,6 @@ async fn proc_spawn_should_return_id_of_spawned_process(
), ),
/* environment */ Environment::new(), /* environment */ Environment::new(),
/* current_dir */ None, /* current_dir */ None,
/* persist */ false,
/* pty */ None, /* pty */ None,
) )
.await .await
@ -1251,7 +1249,6 @@ async fn proc_spawn_should_send_back_stdout_periodically_when_available(
), ),
/* environment */ Environment::new(), /* environment */ Environment::new(),
/* current_dir */ None, /* current_dir */ None,
/* persist */ false,
/* pty */ None, /* pty */ None,
) )
.await .await
@ -1287,7 +1284,6 @@ async fn proc_spawn_should_send_back_stderr_periodically_when_available(
), ),
/* environment */ Environment::new(), /* environment */ Environment::new(),
/* current_dir */ None, /* current_dir */ None,
/* persist */ false,
/* pty */ None, /* pty */ None,
) )
.await .await
@ -1319,7 +1315,6 @@ async fn proc_spawn_should_send_done_signal_when_completed(
format!("{} {} 0.1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()), format!("{} {} 0.1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()),
/* environment */ Environment::new(), /* environment */ Environment::new(),
/* current_dir */ None, /* current_dir */ None,
/* persist */ false,
/* pty */ None, /* pty */ None,
) )
.await .await
@ -1341,7 +1336,6 @@ async fn proc_spawn_should_clear_process_from_state_when_killed(
format!("{} {} 1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()), format!("{} {} 1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()),
/* environment */ Environment::new(), /* environment */ Environment::new(),
/* current_dir */ None, /* current_dir */ None,
/* persist */ false,
/* pty */ None, /* pty */ None,
) )
.await .await
@ -1368,7 +1362,6 @@ async fn proc_kill_should_fail_if_process_not_running(
format!("{} {} 1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()), format!("{} {} 1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()),
/* environment */ Environment::new(), /* environment */ Environment::new(),
/* current_dir */ None, /* current_dir */ None,
/* persist */ false,
/* pty */ None, /* pty */ None,
) )
.await .await
@ -1398,7 +1391,6 @@ async fn proc_stdin_should_fail_if_process_not_running(
format!("{} {} 1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()), format!("{} {} 1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()),
/* environment */ Environment::new(), /* environment */ Environment::new(),
/* current_dir */ None, /* current_dir */ None,
/* persist */ false,
/* pty */ None, /* pty */ None,
) )
.await .await
@ -1434,7 +1426,6 @@ async fn proc_stdin_should_send_stdin_to_process(#[future] launched_client: Ctx<
), ),
/* environment */ Environment::new(), /* environment */ Environment::new(),
/* current_dir */ None, /* current_dir */ None,
/* persist */ false,
/* pty */ None, /* pty */ None,
) )
.await .await

@ -1,6 +1,6 @@
use crate::config::NetworkConfig; use crate::config::NetworkConfig;
use async_trait::async_trait; use async_trait::async_trait;
use distant_core::net::client::{Client as NetClient, ReconnectStrategy}; use distant_core::net::client::{Client as NetClient, ClientConfig, ReconnectStrategy};
use distant_core::net::common::authentication::msg::*; use distant_core::net::common::authentication::msg::*;
use distant_core::net::common::authentication::{ use distant_core::net::common::authentication::{
AuthHandler, AuthMethodHandler, PromptAuthMethodHandler, SingleAuthHandler, AuthHandler, AuthMethodHandler, PromptAuthMethodHandler, SingleAuthHandler,
@ -61,12 +61,15 @@ impl<T: AuthHandler + Clone> Client<T> {
for path in self.network.to_unix_socket_path_candidates() { for path in self.network.to_unix_socket_path_candidates() {
match NetClient::unix_socket(path) match NetClient::unix_socket(path)
.auth_handler(self.auth_handler.clone()) .auth_handler(self.auth_handler.clone())
.reconnect_strategy(ReconnectStrategy::ExponentialBackoff { .config(ClientConfig {
base: Duration::from_secs(1), reconnect_strategy: ReconnectStrategy::ExponentialBackoff {
factor: 2.0, base: Duration::from_secs(1),
max_duration: None, factor: 2.0,
max_retries: None, max_duration: Some(Duration::from_secs(10)),
timeout: None, max_retries: None,
timeout: None,
},
..Default::default()
}) })
.connect() .connect()
.await .await
@ -100,12 +103,15 @@ impl<T: AuthHandler + Clone> Client<T> {
for name in self.network.to_windows_pipe_name_candidates() { for name in self.network.to_windows_pipe_name_candidates() {
match NetClient::local_windows_pipe(name) match NetClient::local_windows_pipe(name)
.auth_handler(self.auth_handler.clone()) .auth_handler(self.auth_handler.clone())
.reconnect_strategy(ReconnectStrategy::ExponentialBackoff { .config(ClientConfig {
base: Duration::from_secs(1), reconnect_strategy: ReconnectStrategy::ExponentialBackoff {
factor: 2.0, base: Duration::from_secs(1),
max_duration: None, factor: 2.0,
max_retries: None, max_duration: Some(Duration::from_secs(10)),
timeout: None, max_retries: None,
timeout: None,
},
..Default::default()
}) })
.connect() .connect()
.await .await

@ -134,11 +134,6 @@ pub enum ClientSubcommand {
#[clap(flatten)] #[clap(flatten)]
network: NetworkConfig, network: NetworkConfig,
/// If provided, will run in persist mode, meaning that the process will not be killed if the
/// client disconnects from the server
#[clap(long)]
persist: bool,
/// If provided, will run LSP in a pty /// If provided, will run LSP in a pty
#[clap(long)] #[clap(long)]
pty: bool, pty: bool,
@ -215,11 +210,6 @@ pub enum ClientSubcommand {
#[clap(long, default_value_t)] #[clap(long, default_value_t)]
environment: Environment, environment: Environment,
/// If provided, will run in persist mode, meaning that the process will not be killed if the
/// client disconnects from the server
#[clap(long)]
persist: bool,
/// Optional command to run instead of $SHELL /// Optional command to run instead of $SHELL
cmd: Option<String>, cmd: Option<String>,
}, },
@ -294,14 +284,12 @@ impl ClientSubcommand {
cmd, cmd,
environment, environment,
current_dir, current_dir,
persist,
pty, pty,
} => { } => {
debug!("Special request spawning {:?}", cmd); debug!("Special request spawning {:?}", cmd);
let mut proc = RemoteCommand::new() let mut proc = RemoteCommand::new()
.environment(environment) .environment(environment)
.current_dir(current_dir) .current_dir(current_dir)
.persist(persist)
.pty(pty) .pty(pty)
.spawn(channel.into_client().into_channel(), cmd.as_str()) .spawn(channel.into_client().into_channel(), cmd.as_str())
.await .await
@ -568,7 +556,6 @@ impl ClientSubcommand {
Self::Lsp { Self::Lsp {
connection, connection,
network, network,
persist,
pty, pty,
cmd, cmd,
.. ..
@ -592,12 +579,9 @@ impl ClientSubcommand {
format!("Failed to open channel to connection {connection_id}") format!("Failed to open channel to connection {connection_id}")
})?; })?;
debug!( debug!("Spawning LSP server (pty = {}): {}", pty, cmd);
"Spawning LSP server (persist = {}, pty = {}): {}",
persist, pty, cmd
);
Lsp::new(channel.into_client().into_channel()) Lsp::new(channel.into_client().into_channel())
.spawn(cmd, persist, pty) .spawn(cmd, pty)
.await?; .await?;
} }
Self::Repl { Self::Repl {
@ -869,7 +853,6 @@ impl ClientSubcommand {
connection, connection,
network, network,
environment, environment,
persist,
cmd, cmd,
.. ..
} => { } => {
@ -893,13 +876,12 @@ impl ClientSubcommand {
})?; })?;
debug!( debug!(
"Spawning shell (environment = {:?}, persist = {}): {}", "Spawning shell (environment = {:?}): {}",
environment, environment,
persist,
cmd.as_deref().unwrap_or(r"$SHELL") cmd.as_deref().unwrap_or(r"$SHELL")
); );
Shell::new(channel.into_client().into_channel()) Shell::new(channel.into_client().into_channel())
.spawn(cmd, environment, persist) .spawn(cmd, environment)
.await?; .await?;
} }
} }

@ -11,10 +11,9 @@ impl Lsp {
Self(channel) Self(channel)
} }
pub async fn spawn(self, cmd: impl Into<String>, persist: bool, pty: bool) -> CliResult { pub async fn spawn(self, cmd: impl Into<String>, pty: bool) -> CliResult {
let cmd = cmd.into(); let cmd = cmd.into();
let mut proc = RemoteLspCommand::new() let mut proc = RemoteLspCommand::new()
.persist(persist)
.pty(if pty { .pty(if pty {
terminal_size().map(|(Width(width), Height(height))| { terminal_size().map(|(Width(width), Height(height))| {
PtySize::from_rows_and_cols(height, width) PtySize::from_rows_and_cols(height, width)

@ -25,7 +25,6 @@ impl Shell {
mut self, mut self,
cmd: impl Into<Option<String>>, cmd: impl Into<Option<String>>,
mut environment: Environment, mut environment: Environment,
persist: bool,
) -> CliResult { ) -> CliResult {
// Automatically add TERM=xterm-256color if not specified // Automatically add TERM=xterm-256color if not specified
if !environment.contains_key("TERM") { if !environment.contains_key("TERM") {
@ -55,7 +54,6 @@ impl Shell {
}; };
let mut proc = RemoteCommand::new() let mut proc = RemoteCommand::new()
.persist(persist)
.environment(environment) .environment(environment)
.pty( .pty(
terminal_size() terminal_size()

@ -1,6 +1,6 @@
use crate::config::ClientLaunchConfig; use crate::config::ClientLaunchConfig;
use async_trait::async_trait; use async_trait::async_trait;
use distant_core::net::client::{Client, ReconnectStrategy, UntypedClient}; use distant_core::net::client::{Client, ClientConfig, ReconnectStrategy, UntypedClient};
use distant_core::net::common::authentication::msg::*; use distant_core::net::common::authentication::msg::*;
use distant_core::net::common::authentication::{ use distant_core::net::common::authentication::{
AuthHandler, Authenticator, DynAuthHandler, ProxyAuthHandler, SingleAuthHandler, AuthHandler, Authenticator, DynAuthHandler, ProxyAuthHandler, SingleAuthHandler,
@ -210,14 +210,17 @@ impl DistantConnectHandler {
match Client::tcp(addr) match Client::tcp(addr)
.auth_handler(DynAuthHandler::from(&mut auth_handler)) .auth_handler(DynAuthHandler::from(&mut auth_handler))
.reconnect_strategy(ReconnectStrategy::ExponentialBackoff { .config(ClientConfig {
base: Duration::from_secs(1), reconnect_strategy: ReconnectStrategy::ExponentialBackoff {
factor: 2.0, base: Duration::from_secs(1),
max_duration: None, factor: 2.0,
max_retries: None, max_duration: Some(Duration::from_secs(10)),
timeout: None, max_retries: None,
timeout: None,
},
..Default::default()
}) })
.timeout(Duration::from_secs(180)) .connect_timeout(Duration::from_secs(180))
.connect_untyped() .connect_untyped()
.await .await
{ {

@ -84,7 +84,6 @@ async fn should_support_json_to_execute_program_and_return_exit_status(
"payload": { "payload": {
"type": "proc_spawn", "type": "proc_spawn",
"cmd": cmd, "cmd": cmd,
"persist": false,
"pty": null, "pty": null,
}, },
}); });
@ -109,7 +108,6 @@ async fn should_support_json_to_capture_and_print_stdout(mut json_repl: CtxComma
"payload": { "payload": {
"type": "proc_spawn", "type": "proc_spawn",
"cmd": cmd, "cmd": cmd,
"persist": false,
"pty": null, "pty": null,
}, },
}); });
@ -148,7 +146,6 @@ async fn should_support_json_to_capture_and_print_stderr(mut json_repl: CtxComma
"payload": { "payload": {
"type": "proc_spawn", "type": "proc_spawn",
"cmd": cmd, "cmd": cmd,
"persist": false,
"pty": null, "pty": null,
}, },
}); });
@ -187,7 +184,6 @@ async fn should_support_json_to_forward_stdin_to_remote_process(mut json_repl: C
"payload": { "payload": {
"type": "proc_spawn", "type": "proc_spawn",
"cmd": cmd, "cmd": cmd,
"persist": false,
"pty": null, "pty": null,
}, },
}); });
@ -271,7 +267,6 @@ async fn should_support_json_output_for_error(mut json_repl: CtxCommand<Repl>) {
"payload": { "payload": {
"type": "proc_spawn", "type": "proc_spawn",
"cmd": DOES_NOT_EXIST_BIN.to_str().unwrap().to_string(), "cmd": DOES_NOT_EXIST_BIN.to_str().unwrap().to_string(),
"persist": false,
"pty": null, "pty": null,
}, },
}); });

Loading…
Cancel
Save