Add heartbeat support (#153)

* Update to support zero-size frame items

* Add heartbeat functionality with client reconnecting logic

* Fix connection reauthentication failures preventing future reauthentication

* More logging

* Remove persist

* Update connection logic to have server take on client id rather than having client take on server id during reconnect

* Bump minimum rust version to 1.64.0

* Bump to v0.20.0-alpha.3 and fix clippy warnings

* Update cargo.lock
pull/156/head v0.20.0-alpha.3
Chip Senkbeil 1 year ago committed by GitHub
parent ee595551ae
commit ee50eaf9b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -86,7 +86,7 @@ jobs:
- { rust: stable, os: windows-latest, target: x86_64-pc-windows-msvc }
- { rust: stable, os: macos-latest }
- { rust: stable, os: ubuntu-latest }
- { rust: 1.61.0, os: ubuntu-latest }
- { rust: 1.64.0, os: ubuntu-latest }
steps:
- uses: actions/checkout@v3
- name: Install Rust ${{ matrix.rust }}

@ -7,6 +7,29 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
## [0.20.0-alpha.3]
### Added
- `Frame::empty` method as convenience for `Frame::new(&[])`
- `ClientConfig` to support `ReconnectStrategy` and a duration representing the
  maximum time to wait without any server activity before the client attempts
  to reconnect
- Server sends empty frames periodically to act as heartbeats to let the client
know if the connection is still established
- Client now tracks the length of time since the last server activity and will
  attempt a reconnect if that duration elapses with no activity
### Changed
- `Frame` methods `read` and `write` no longer return an `io::Result<...>`
and instead return `Option<Frame<...>>` and nothing respectively
- `Frame::read` method now supports zero-size items
- `Client::inmemory_spawn` and `UntypedClient::inmemory_spawn` now take a
`ClientConfig` as the second argument instead of `ReconnectStrategy`
- Persist option now removed from `ProcSpawn` message and CLI
- Bump minimum Rust version to 1.64.0
## [0.20.0-alpha.2] - 2022-11-20
### Added

593
Cargo.lock generated

File diff suppressed because it is too large Load Diff

@ -3,7 +3,7 @@ name = "distant"
description = "Operate on a remote computer through file and process manipulation"
categories = ["command-line-utilities"]
keywords = ["cli"]
version = "0.20.0-alpha.2"
version = "0.20.0-alpha.3"
authors = ["Chip Senkbeil <chip@senkbeil.org>"]
edition = "2021"
homepage = "https://github.com/chipsenkbeil/distant"
@ -32,7 +32,7 @@ clap_complete = "4.0.5"
config = { version = "0.13.2", default-features = false, features = ["toml"] }
derive_more = { version = "0.99.17", default-features = false, features = ["display", "from", "error", "is_variant"] }
dialoguer = { version = "0.10.2", default-features = false }
distant-core = { version = "=0.20.0-alpha.2", path = "distant-core", features = ["clap", "schemars"] }
distant-core = { version = "=0.20.0-alpha.3", path = "distant-core", features = ["clap", "schemars"] }
directories = "4.0.1"
flexi_logger = "0.24.1"
indoc = "1.0.7"
@ -54,7 +54,7 @@ winsplit = "0.1.0"
whoami = "1.2.3"
# Optional native SSH functionality
distant-ssh2 = { version = "=0.20.0-alpha.2", path = "distant-ssh2", default-features = false, features = ["serde"], optional = true }
distant-ssh2 = { version = "=0.20.0-alpha.3", path = "distant-ssh2", default-features = false, features = ["serde"], optional = true }
[target.'cfg(unix)'.dependencies]
fork = "0.1.20"

@ -3,7 +3,7 @@ name = "distant-core"
description = "Core library for distant, enabling operation on a remote computer through file and process manipulation"
categories = ["network-programming"]
keywords = ["api", "async"]
version = "0.20.0-alpha.2"
version = "0.20.0-alpha.3"
authors = ["Chip Senkbeil <chip@senkbeil.org>"]
edition = "2021"
homepage = "https://github.com/chipsenkbeil/distant"
@ -19,7 +19,7 @@ async-trait = "0.1.58"
bitflags = "1.3.2"
bytes = "1.2.1"
derive_more = { version = "0.99.17", default-features = false, features = ["as_mut", "as_ref", "deref", "deref_mut", "display", "from", "error", "into", "into_iterator", "is_variant", "try_into"] }
distant-net = { version = "=0.20.0-alpha.2", path = "../distant-net" }
distant-net = { version = "=0.20.0-alpha.3", path = "../distant-net" }
futures = "0.3.25"
grep = "0.2.10"
hex = "0.4.3"

@ -348,8 +348,6 @@ pub trait DistantApi {
/// * `cmd` - the full command to run as a new process (including arguments)
/// * `environment` - the environment variables to associate with the process
/// * `current_dir` - the alternative current directory to use with the process
/// * `persist` - if true, the process will continue running even after the connection that
/// spawned the process has terminated
/// * `pty` - if provided, will run the process within a PTY of the given size
///
/// *Override this, otherwise it will return "unsupported" as an error.*
@ -360,7 +358,6 @@ pub trait DistantApi {
cmd: String,
environment: Environment,
current_dir: Option<PathBuf>,
persist: bool,
pty: Option<PtySize>,
) -> io::Result<ProcessId> {
unsupported("proc_spawn")
@ -650,11 +647,10 @@ where
cmd,
environment,
current_dir,
persist,
pty,
} => server
.api
.proc_spawn(ctx, cmd.into(), environment, current_dir, persist, pty)
.proc_spawn(ctx, cmd.into(), environment, current_dir, pty)
.await
.map(|id| DistantResponseData::ProcSpawned { id })
.unwrap_or_else(DistantResponseData::from),

@ -6,7 +6,6 @@ use crate::{
DistantApi, DistantCtx,
};
use async_trait::async_trait;
use distant_net::server::ConnectionCtx;
use log::*;
use std::{
io,
@ -18,7 +17,6 @@ use walkdir::WalkDir;
mod process;
mod state;
pub use state::ConnectionState;
use state::*;
/// Represents an implementation of [`DistantApi`] that works with the local machine
@ -40,14 +38,7 @@ impl LocalDistantApi {
#[async_trait]
impl DistantApi for LocalDistantApi {
type LocalData = ConnectionState;
/// Injects the global channels into the local connection
async fn on_accept(&self, ctx: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> {
ctx.local_data.process_channel = self.state.process.clone_channel();
ctx.local_data.watcher_channel = self.state.watcher.clone_channel();
Ok(())
}
type LocalData = ();
async fn capabilities(&self, ctx: DistantCtx<Self::LocalData>) -> io::Result<Capabilities> {
debug!("[Conn {}] Querying capabilities", ctx.connection_id);
@ -451,16 +442,15 @@ impl DistantApi for LocalDistantApi {
cmd: String,
environment: Environment,
current_dir: Option<PathBuf>,
persist: bool,
pty: Option<PtySize>,
) -> io::Result<ProcessId> {
debug!(
"[Conn {}] Spawning {} {{environment: {:?}, current_dir: {:?}, persist: {}, pty: {:?}}}",
ctx.connection_id, cmd, environment, current_dir, persist, pty
"[Conn {}] Spawning {} {{environment: {:?}, current_dir: {:?}, pty: {:?}}}",
ctx.connection_id, cmd, environment, current_dir, pty
);
self.state
.process
.spawn(cmd, environment, current_dir, persist, pty, ctx.reply)
.spawn(cmd, environment, current_dir, pty, ctx.reply)
.await
}
@ -504,6 +494,7 @@ impl DistantApi for LocalDistantApi {
#[cfg(test)]
mod tests {
use super::*;
use crate::api::ConnectionCtx;
use crate::data::DistantResponseData;
use assert_fs::prelude::*;
use distant_net::server::Reply;
@ -576,18 +567,18 @@ mod tests {
buffer: usize,
) -> (
LocalDistantApi,
DistantCtx<ConnectionState>,
DistantCtx<()>,
mpsc::Receiver<DistantResponseData>,
) {
let api = LocalDistantApi::initialize().unwrap();
let (reply, rx) = make_reply(buffer);
let connection_id = rand::random();
let mut local_data = ConnectionState::default();
DistantApi::on_accept(
&api,
ConnectionCtx {
connection_id,
local_data: &mut local_data,
local_data: &mut (),
},
)
.await
@ -595,7 +586,7 @@ mod tests {
let ctx = DistantCtx {
connection_id,
reply,
local_data: Arc::new(local_data),
local_data: Arc::new(()),
};
(api, ctx, rx)
}
@ -1842,7 +1833,6 @@ mod tests {
/* cmd */ DOES_NOT_EXIST_BIN.to_str().unwrap().to_string(),
/* environment */ Environment::new(),
/* current_dir */ None,
/* persist */ false,
/* pty */ None,
)
.await
@ -1867,7 +1857,6 @@ mod tests {
),
/* environment */ Environment::new(),
/* current_dir */ None,
/* persist */ false,
/* pty */ None,
)
.await
@ -1893,7 +1882,6 @@ mod tests {
),
/* environment */ Environment::new(),
/* current_dir */ None,
/* persist */ false,
/* pty */ None,
)
.await
@ -1958,7 +1946,6 @@ mod tests {
),
/* environment */ Environment::new(),
/* current_dir */ None,
/* persist */ false,
/* pty */ None,
)
.await
@ -2019,7 +2006,6 @@ mod tests {
format!("{} {} 0.1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()),
/* environment */ Environment::new(),
/* current_dir */ None,
/* persist */ false,
/* pty */ None,
)
.await
@ -2059,7 +2045,6 @@ mod tests {
format!("{} {} 1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()),
/* environment */ Environment::new(),
/* current_dir */ None,
/* persist */ false,
/* pty */ None,
)
.await
@ -2126,7 +2111,6 @@ mod tests {
),
Environment::new(),
/* current_dir */ None,
/* persist */ false,
/* pty */ None,
)
.await

@ -1,6 +1,4 @@
use crate::data::{ProcessId, SearchId};
use distant_net::common::ConnectionId;
use std::{io, path::PathBuf};
use std::io;
mod process;
pub use process::*;
@ -32,57 +30,3 @@ impl GlobalState {
})
}
}
/// Holds connection-specific state managed by the server
#[derive(Default)]
pub struct ConnectionState {
/// Unique id associated with connection
id: ConnectionId,
/// Channel connected to global process state
pub(crate) process_channel: ProcessChannel,
/// Channel connected to global search state
pub(crate) search_channel: SearchChannel,
/// Channel connected to global watcher state
pub(crate) watcher_channel: WatcherChannel,
/// Contains ids of processes that will be terminated when the connection is closed
processes: Vec<ProcessId>,
/// Contains paths being watched that will be unwatched when the connection is closed
paths: Vec<PathBuf>,
/// Contains ids of searches that will be terminated when the connection is closed
searches: Vec<SearchId>,
}
impl Drop for ConnectionState {
fn drop(&mut self) {
let id = self.id;
let processes: Vec<ProcessId> = self.processes.drain(..).collect();
let paths: Vec<PathBuf> = self.paths.drain(..).collect();
let searches: Vec<SearchId> = self.searches.drain(..).collect();
let process_channel = self.process_channel.clone();
let search_channel = self.search_channel.clone();
let watcher_channel = self.watcher_channel.clone();
// NOTE: We cannot (and should not) block during drop to perform cleanup,
// instead spawning a task that will do the cleanup async
tokio::spawn(async move {
for id in processes {
let _ = process_channel.kill(id).await;
}
for id in searches {
let _ = search_channel.cancel(id).await;
}
for path in paths {
let _ = watcher_channel.unwatch(id, path).await;
}
});
}
}

@ -9,14 +9,14 @@ use tokio::{
mod instance;
pub use instance::*;
/// Holds information related to spawned processes on the server
/// Holds information related to spawned processes on the server.
pub struct ProcessState {
channel: ProcessChannel,
task: JoinHandle<()>,
}
impl Drop for ProcessState {
/// Aborts the task that handles process operations and management
/// Aborts the task that handles process operations and management.
fn drop(&mut self) {
self.abort();
}
@ -33,10 +33,6 @@ impl ProcessState {
}
}
pub fn clone_channel(&self) -> ProcessChannel {
self.channel.clone()
}
/// Aborts the process task
pub fn abort(&self) {
self.task.abort();
@ -57,7 +53,7 @@ pub struct ProcessChannel {
}
impl Default for ProcessChannel {
/// Creates a new channel that is closed by default
/// Creates a new channel that is closed by default.
fn default() -> Self {
let (tx, _) = mpsc::channel(1);
Self { tx }
@ -65,13 +61,12 @@ impl Default for ProcessChannel {
}
impl ProcessChannel {
/// Spawns a new process, returning the id associated with it
/// Spawns a new process, returning the id associated with it.
pub async fn spawn(
&self,
cmd: String,
environment: Environment,
current_dir: Option<PathBuf>,
persist: bool,
pty: Option<PtySize>,
reply: Box<dyn Reply<Data = DistantResponseData>>,
) -> io::Result<ProcessId> {
@ -81,7 +76,6 @@ impl ProcessChannel {
cmd,
environment,
current_dir,
persist,
pty,
reply,
cb,
@ -92,7 +86,7 @@ impl ProcessChannel {
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Response to spawn dropped"))?
}
/// Resizes the pty of a running process
/// Resizes the pty of a running process.
pub async fn resize_pty(&self, id: ProcessId, size: PtySize) -> io::Result<()> {
let (cb, rx) = oneshot::channel();
self.tx
@ -103,7 +97,7 @@ impl ProcessChannel {
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Response to resize dropped"))?
}
/// Send stdin to a running process
/// Send stdin to a running process.
pub async fn send_stdin(&self, id: ProcessId, data: Vec<u8>) -> io::Result<()> {
let (cb, rx) = oneshot::channel();
self.tx
@ -114,7 +108,8 @@ impl ProcessChannel {
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Response to stdin dropped"))?
}
/// Kills a running process
/// Kills a running process. Will fail if no process is found with the given
/// id or if the process cannot be killed.
pub async fn kill(&self, id: ProcessId) -> io::Result<()> {
let (cb, rx) = oneshot::channel();
self.tx
@ -126,13 +121,12 @@ impl ProcessChannel {
}
}
/// Internal message to pass to our task below to perform some action
/// Internal message to pass to our task below to perform some action.
enum InnerProcessMsg {
Spawn {
cmd: String,
environment: Environment,
current_dir: Option<PathBuf>,
persist: bool,
pty: Option<PtySize>,
reply: Box<dyn Reply<Data = DistantResponseData>>,
cb: oneshot::Sender<io::Result<ProcessId>>,
@ -165,14 +159,12 @@ async fn process_task(tx: mpsc::Sender<InnerProcessMsg>, mut rx: mpsc::Receiver<
cmd,
environment,
current_dir,
persist,
pty,
reply,
cb,
} => {
let _ = cb.send(
match ProcessInstance::spawn(cmd, environment, current_dir, persist, pty, reply)
{
match ProcessInstance::spawn(cmd, environment, current_dir, pty, reply) {
Ok(mut process) => {
let id = process.id;
@ -195,7 +187,7 @@ async fn process_task(tx: mpsc::Sender<InnerProcessMsg>, mut rx: mpsc::Receiver<
Some(process) => process.pty.resize_pty(size),
None => Err(io::Error::new(
io::ErrorKind::Other,
format!("No process found with id {}", id),
format!("No process found with id {id}"),
)),
});
}
@ -205,12 +197,12 @@ async fn process_task(tx: mpsc::Sender<InnerProcessMsg>, mut rx: mpsc::Receiver<
Some(stdin) => stdin.send(&data).await,
None => Err(io::Error::new(
io::ErrorKind::Other,
format!("Process {} stdin is closed", id),
format!("Process {id} stdin is closed"),
)),
},
None => Err(io::Error::new(
io::ErrorKind::Other,
format!("No process found with id {}", id),
format!("No process found with id {id}"),
)),
});
}
@ -219,7 +211,7 @@ async fn process_task(tx: mpsc::Sender<InnerProcessMsg>, mut rx: mpsc::Receiver<
Some(process) => process.killer.kill().await,
None => Err(io::Error::new(
io::ErrorKind::Other,
format!("No process found with id {}", id),
format!("No process found with id {id}"),
)),
});
}

@ -13,7 +13,6 @@ use tokio::task::JoinHandle;
pub struct ProcessInstance {
pub cmd: String,
pub args: Vec<String>,
pub persist: bool,
pub id: ProcessId,
pub stdin: Option<Box<dyn InputChannel>>,
@ -63,7 +62,6 @@ impl ProcessInstance {
cmd: String,
environment: Environment,
current_dir: Option<PathBuf>,
persist: bool,
pty: Option<PtySize>,
reply: Box<dyn Reply<Data = DistantResponseData>>,
) -> io::Result<Self> {
@ -135,7 +133,6 @@ impl ProcessInstance {
Ok(ProcessInstance {
cmd,
args,
persist,
id,
stdin,
killer,

@ -95,10 +95,6 @@ impl WatcherState {
}
}
pub fn clone_channel(&self) -> WatcherChannel {
self.channel.clone()
}
/// Aborts the watcher task
pub fn abort(&self) {
self.task.abort();

@ -101,7 +101,6 @@ pub trait DistantChannelExt {
cmd: impl Into<String>,
environment: Environment,
current_dir: Option<PathBuf>,
persist: bool,
pty: Option<PtySize>,
) -> AsyncReturn<'_, RemoteProcess>;
@ -111,7 +110,6 @@ pub trait DistantChannelExt {
cmd: impl Into<String>,
environment: Environment,
current_dir: Option<PathBuf>,
persist: bool,
pty: Option<PtySize>,
) -> AsyncReturn<'_, RemoteLspProcess>;
@ -369,7 +367,6 @@ impl DistantChannelExt
cmd: impl Into<String>,
environment: Environment,
current_dir: Option<PathBuf>,
persist: bool,
pty: Option<PtySize>,
) -> AsyncReturn<'_, RemoteProcess> {
let cmd = cmd.into();
@ -377,7 +374,6 @@ impl DistantChannelExt
RemoteCommand::new()
.environment(environment)
.current_dir(current_dir)
.persist(persist)
.pty(pty)
.spawn(self.clone(), cmd)
.await
@ -389,7 +385,6 @@ impl DistantChannelExt
cmd: impl Into<String>,
environment: Environment,
current_dir: Option<PathBuf>,
persist: bool,
pty: Option<PtySize>,
) -> AsyncReturn<'_, RemoteLspProcess> {
let cmd = cmd.into();
@ -397,7 +392,6 @@ impl DistantChannelExt
RemoteLspCommand::new()
.environment(environment)
.current_dir(current_dir)
.persist(persist)
.pty(pty)
.spawn(self.clone(), cmd)
.await
@ -416,7 +410,6 @@ impl DistantChannelExt
RemoteCommand::new()
.environment(environment)
.current_dir(current_dir)
.persist(false)
.pty(pty)
.spawn(self.clone(), cmd)
.await?

@ -22,7 +22,6 @@ pub use msg::*;
/// A [`RemoteLspProcess`] builder providing support to configure
/// before spawning the process on a remote machine
pub struct RemoteLspCommand {
persist: bool,
pty: Option<PtySize>,
environment: Environment,
current_dir: Option<PathBuf>,
@ -38,21 +37,12 @@ impl RemoteLspCommand {
/// Creates a new set of options for a remote LSP process
pub fn new() -> Self {
Self {
persist: false,
pty: None,
environment: Environment::new(),
current_dir: None,
}
}
/// Sets whether or not the process will be persistent,
/// meaning that it will not be terminated when the
/// connection to the remote machine is terminated
pub fn persist(&mut self, persist: bool) -> &mut Self {
self.persist = persist;
self
}
/// Configures the process to leverage a PTY with the specified size
pub fn pty(&mut self, pty: Option<PtySize>) -> &mut Self {
self.pty = pty;
@ -81,7 +71,6 @@ impl RemoteLspCommand {
let mut command = RemoteCommand::new();
command.environment(self.environment.clone());
command.current_dir(self.current_dir.clone());
command.persist(self.persist);
command.pty(self.pty);
let mut inner = command.spawn(channel, cmd).await?;
@ -412,7 +401,7 @@ mod tests {
use crate::data::{DistantRequestData, DistantResponseData};
use distant_net::{
common::{FramedTransport, InmemoryTransport, Request, Response},
Client, ReconnectStrategy,
Client,
};
use std::{future::Future, time::Duration};
use test_log::test;
@ -423,7 +412,7 @@ mod tests {
// Configures an lsp process with a means to send & receive data from outside
async fn spawn_lsp_process() -> (FramedTransport<InmemoryTransport>, RemoteLspProcess) {
let (mut t1, t2) = FramedTransport::pair(100);
let client = Client::spawn_inmemory(t2, ReconnectStrategy::Fail);
let client = Client::spawn_inmemory(t2, Default::default());
let spawn_task = tokio::spawn({
let channel = client.clone_channel();
async move {

@ -47,7 +47,6 @@ type StatusResult = io::Result<RemoteStatus>;
/// A [`RemoteProcess`] builder providing support to configure
/// before spawning the process on a remote machine
pub struct RemoteCommand {
persist: bool,
pty: Option<PtySize>,
environment: Environment,
current_dir: Option<PathBuf>,
@ -63,21 +62,12 @@ impl RemoteCommand {
/// Creates a new set of options for a remote process
pub fn new() -> Self {
Self {
persist: false,
pty: None,
environment: Environment::new(),
current_dir: None,
}
}
/// Sets whether or not the process will be persistent,
/// meaning that it will not be terminated when the
/// connection to the remote machine is terminated
pub fn persist(&mut self, persist: bool) -> &mut Self {
self.persist = persist;
self
}
/// Configures the process to leverage a PTY with the specified size
pub fn pty(&mut self, pty: Option<PtySize>) -> &mut Self {
self.pty = pty;
@ -109,7 +99,6 @@ impl RemoteCommand {
.mail(Request::new(DistantMsg::Single(
DistantRequestData::ProcSpawn {
cmd: Cmd::from(cmd),
persist: self.persist,
pty: self.pty,
environment: self.environment.clone(),
current_dir: self.current_dir.clone(),
@ -613,14 +602,14 @@ mod tests {
};
use distant_net::{
common::{FramedTransport, InmemoryTransport, Response},
Client, ReconnectStrategy,
Client,
};
use std::time::Duration;
use test_log::test;
fn make_session() -> (FramedTransport<InmemoryTransport>, DistantClient) {
let (t1, t2) = FramedTransport::pair(100);
(t1, Client::spawn_inmemory(t2, ReconnectStrategy::Fail))
(t1, Client::spawn_inmemory(t2, Default::default()))
}
#[test(tokio::test)]

@ -198,7 +198,7 @@ mod tests {
use crate::DistantClient;
use distant_net::{
common::{FramedTransport, InmemoryTransport, Response},
Client, ReconnectStrategy,
Client,
};
use std::{path::PathBuf, sync::Arc};
use test_log::test;
@ -206,7 +206,7 @@ mod tests {
fn make_session() -> (FramedTransport<InmemoryTransport>, DistantClient) {
let (t1, t2) = FramedTransport::pair(100);
(t1, Client::spawn_inmemory(t2, ReconnectStrategy::Fail))
(t1, Client::spawn_inmemory(t2, Default::default()))
}
#[test(tokio::test)]

@ -186,7 +186,7 @@ mod tests {
use crate::DistantClient;
use distant_net::{
common::{FramedTransport, InmemoryTransport, Response},
Client, ReconnectStrategy,
Client,
};
use std::sync::Arc;
use test_log::test;
@ -194,7 +194,7 @@ mod tests {
fn make_session() -> (FramedTransport<InmemoryTransport>, DistantClient) {
let (t1, t2) = FramedTransport::pair(100);
(t1, Client::spawn_inmemory(t2, ReconnectStrategy::Fail))
(t1, Client::spawn_inmemory(t2, Default::default()))
}
#[test(tokio::test)]

@ -417,12 +417,6 @@ pub enum DistantRequestData {
#[cfg_attr(feature = "clap", clap(long))]
current_dir: Option<PathBuf>,
/// Whether or not the process should be persistent, meaning that the process will not be
/// killed when the associated client disconnects
#[serde(default)]
#[cfg_attr(feature = "clap", clap(long))]
persist: bool,
/// If provided, will spawn process in a pty, otherwise spawns directly
#[serde(default)]
#[cfg_attr(feature = "clap", clap(long))]

@ -45,7 +45,7 @@ impl DistantClientCtx {
// Now initialize our client
let client: DistantClient = Client::build()
.auth_handler(DummyAuthHandler)
.timeout(Duration::from_secs(1))
.connect_timeout(Duration::from_secs(1))
.connector(TcpConnector::new(
format!("{}:{}", ip_addr, port)
.parse::<SocketAddr>()

@ -3,7 +3,7 @@ name = "distant-net"
description = "Network library for distant, providing implementations to support client/server architecture"
categories = ["network-programming"]
keywords = ["api", "async"]
version = "0.20.0-alpha.2"
version = "0.20.0-alpha.3"
authors = ["Chip Senkbeil <chip@senkbeil.org>"]
edition = "2021"
homepage = "https://github.com/chipsenkbeil/distant"

@ -8,7 +8,7 @@ use std::{
fmt, io,
ops::{Deref, DerefMut},
sync::Arc,
time::Duration,
time::{Duration, Instant},
};
use tokio::{
sync::{mpsc, oneshot, watch},
@ -21,6 +21,9 @@ pub use builder::*;
mod channel;
pub use channel::*;
mod config;
pub use config::*;
mod reconnect;
pub use reconnect::*;
@ -135,18 +138,18 @@ impl UntypedClient {
/// within a program.
pub fn spawn_inmemory(
transport: FramedTransport<InmemoryTransport>,
strategy: ReconnectStrategy,
config: ClientConfig,
) -> Self {
let connection = Connection::Client {
id: rand::random(),
reauth_otp: HeapSecretKey::generate(32).unwrap(),
transport,
};
Self::spawn(connection, strategy)
Self::spawn(connection, config)
}
/// Spawns a client using the provided [`Connection`].
pub(crate) fn spawn<V>(mut connection: Connection<V>, mut strategy: ReconnectStrategy) -> Self
pub(crate) fn spawn<V>(mut connection: Connection<V>, config: ClientConfig) -> Self
where
V: Transport + 'static,
{
@ -164,6 +167,7 @@ impl UntypedClient {
let (watcher_tx, watcher_rx) = watch::channel(ConnectionState::Connected);
let task = tokio::spawn(async move {
let mut needs_reconnect = false;
let mut last_read_frame_time = Instant::now();
// NOTE: We hold onto a copy of the shutdown sender, even though we will never use it,
// to prevent the channel from being closed. This is because we do a check to
@ -171,19 +175,24 @@ impl UntypedClient {
// would cause recv() to resolve immediately and result in the task shutting
// down.
let _shutdown_tx = shutdown_tx_2;
let ClientConfig {
mut reconnect_strategy,
silence_duration,
} = config;
loop {
// If we have flagged that a reconnect is needed, attempt to do so
if needs_reconnect {
info!("Client encountered issue, attempting to reconnect");
if log::log_enabled!(log::Level::Debug) {
debug!("Using strategy {strategy:?}");
debug!("Using strategy {reconnect_strategy:?}");
}
match strategy.reconnect(&mut connection).await {
Ok(x) => {
match reconnect_strategy.reconnect(&mut connection).await {
Ok(()) => {
info!("Client successfully reconnected!");
needs_reconnect = false;
last_read_frame_time = Instant::now();
watcher_tx.send_replace(ConnectionState::Connected);
x
}
Err(x) => {
error!("Unable to re-establish connection: {x}");
@ -193,6 +202,29 @@ impl UntypedClient {
}
}
macro_rules! silence_needs_reconnect {
() => {{
debug!(
"Client exceeded {}s without server activity, so attempting to reconnect",
silence_duration.as_secs_f32(),
);
needs_reconnect = true;
watcher_tx.send_replace(ConnectionState::Reconnecting);
continue;
}};
}
let silence_time_remaining = silence_duration
.checked_sub(last_read_frame_time.elapsed())
.unwrap_or_default();
// NOTE: sleep will not trigger if duration is zero/nanosecond scale, so we
// instead will do an early check here in the case that we need to reconnect
// prior to a sleep while polling the ready status
if silence_time_remaining.as_millis() == 0 {
silence_needs_reconnect!();
}
let ready = tokio::select! {
// NOTE: This should NEVER return None as we never allow the channel to close.
cb = shutdown_rx.recv() => {
@ -202,6 +234,9 @@ impl UntypedClient {
watcher_tx.send_replace(ConnectionState::Disconnected);
break Ok(());
}
_ = tokio::time::sleep(silence_time_remaining) => {
silence_needs_reconnect!();
}
result = connection.ready(Interest::READABLE | Interest::WRITABLE) => {
match result {
Ok(result) => result,
@ -220,7 +255,16 @@ impl UntypedClient {
if ready.is_readable() {
match connection.try_read_frame() {
// If we get an empty frame, we consider this a heartbeat and want
// to adjust our frame read time and discard it from our backup
Ok(Some(frame)) if frame.is_empty() => {
trace!("Client received heartbeat");
last_read_frame_time = Instant::now();
}
// Otherwise, we attempt to parse a frame as a response
Ok(Some(frame)) => {
last_read_frame_time = Instant::now();
match UntypedResponse::from_slice(frame.as_item()) {
Ok(response) => {
if log_enabled!(Level::Trace) {
@ -242,6 +286,7 @@ impl UntypedClient {
}
}
}
Ok(None) => {
debug!("Connection closed");
needs_reconnect = true;
@ -391,9 +436,9 @@ where
/// within a program.
pub fn spawn_inmemory(
transport: FramedTransport<InmemoryTransport>,
strategy: ReconnectStrategy,
config: ClientConfig,
) -> Self {
UntypedClient::spawn_inmemory(transport, strategy).into_typed_client()
UntypedClient::spawn_inmemory(transport, config).into_typed_client()
}
}
@ -515,6 +560,7 @@ impl<T, U> From<Client<T, U>> for Channel<T, U> {
#[cfg(test)]
mod tests {
use super::*;
use crate::client::ClientConfig;
use crate::common::{Ready, Request, Response, TestTransport};
mod typed {
@ -524,12 +570,19 @@ mod tests {
fn spawn_test_client<T>(
connection: Connection<T>,
strategy: ReconnectStrategy,
reconnect_strategy: ReconnectStrategy,
) -> TestClient
where
T: Transport + 'static,
{
UntypedClient::spawn(connection, strategy).into_typed_client()
UntypedClient::spawn(
connection,
ClientConfig {
reconnect_strategy,
..Default::default()
},
)
.into_typed_client()
}
/// Creates a new test transport whose operations do not panic, but do nothing.
@ -848,7 +901,7 @@ mod tests {
async fn should_write_queued_requests_as_outgoing_frames() {
let (client, mut server) = Connection::pair(100);
let mut client = TestClient::spawn(client, ReconnectStrategy::Fail);
let mut client = TestClient::spawn(client, Default::default());
client
.fire(Request::new(1u8).to_untyped_request().unwrap())
.await
@ -908,7 +961,7 @@ mod tests {
.unwrap();
});
let mut client = TestClient::spawn(client, ReconnectStrategy::Fail);
let mut client = TestClient::spawn(client, Default::default());
assert_eq!(
client
.send(Request::new(1u8).to_untyped_request().unwrap())
@ -938,10 +991,13 @@ mod tests {
transport
}),
ReconnectStrategy::FixedInterval {
interval: Duration::from_millis(50),
max_retries: None,
timeout: None,
ClientConfig {
reconnect_strategy: ReconnectStrategy::FixedInterval {
interval: Duration::from_millis(50),
max_retries: None,
timeout: None,
},
..Default::default()
},
);
@ -969,10 +1025,13 @@ mod tests {
transport
}),
ReconnectStrategy::FixedInterval {
interval: Duration::from_millis(50),
max_retries: None,
timeout: None,
ClientConfig {
reconnect_strategy: ReconnectStrategy::FixedInterval {
interval: Duration::from_millis(50),
max_retries: None,
timeout: None,
},
..Default::default()
},
);
@ -1000,10 +1059,13 @@ mod tests {
transport
}),
ReconnectStrategy::FixedInterval {
interval: Duration::from_millis(50),
max_retries: None,
timeout: None,
ClientConfig {
reconnect_strategy: ReconnectStrategy::FixedInterval {
interval: Duration::from_millis(50),
max_retries: None,
timeout: None,
},
..Default::default()
},
);
@ -1031,10 +1093,13 @@ mod tests {
transport
}),
ReconnectStrategy::FixedInterval {
interval: Duration::from_millis(50),
max_retries: None,
timeout: None,
ClientConfig {
reconnect_strategy: ReconnectStrategy::FixedInterval {
interval: Duration::from_millis(50),
max_retries: None,
timeout: None,
},
..Default::default()
},
);
@ -1079,10 +1144,13 @@ mod tests {
transport
}),
ReconnectStrategy::FixedInterval {
interval: Duration::from_millis(50),
max_retries: None,
timeout: None,
ClientConfig {
reconnect_strategy: ReconnectStrategy::FixedInterval {
interval: Duration::from_millis(50),
max_retries: None,
timeout: None,
},
..Default::default()
},
);
@ -1101,7 +1169,7 @@ mod tests {
// Spawn the client, verify the task is running, kill our server, and verify that the
// client does not block trying to reconnect
let client = TestClient::spawn(client, ReconnectStrategy::Fail);
let client = TestClient::spawn(client, Default::default());
assert!(!client.is_finished(), "Client unexpectedly died");
drop(server);
assert_eq!(
@ -1114,7 +1182,7 @@ mod tests {
async fn should_exit_if_shutdown_signal_detected() {
let (client, _server) = Connection::pair(100);
let client = TestClient::spawn(client, ReconnectStrategy::Fail);
let client = TestClient::spawn(client, Default::default());
client.shutdown().await.unwrap();
// NOTE: We wait for the client's task to conclude by using `wait` to ensure we do not
@ -1142,7 +1210,7 @@ mod tests {
// NOTE: We consume the client to produce a channel without maintaining the shutdown
// channel in order to ensure that dropping the client does not kill the task.
let mut channel = TestClient::spawn(client, ReconnectStrategy::Fail).into_channel();
let mut channel = TestClient::spawn(client, Default::default()).into_channel();
assert_eq!(
channel
.send(Request::new(1u8).to_untyped_request().unwrap())
@ -1154,5 +1222,30 @@ mod tests {
2
);
}
#[test(tokio::test)]
async fn should_attempt_to_reconnect_if_no_activity_from_server_within_silence_duration() {
let (client, _) = Connection::pair(100);
// NOTE: The server half of the connection pair is dropped immediately, so the
// client observes no server activity and should trigger its silence-based reconnect.
let client = TestClient::spawn(
client,
ClientConfig {
silence_duration: Duration::from_millis(100),
reconnect_strategy: ReconnectStrategy::FixedInterval {
interval: Duration::from_millis(50),
max_retries: Some(3),
timeout: None,
},
},
);
let (tx, mut rx) = mpsc::unbounded_channel();
client.on_connection_change(move |state| tx.send(state).unwrap());
assert_eq!(rx.recv().await, Some(ConnectionState::Reconnecting));
assert_eq!(rx.recv().await, Some(ConnectionState::Disconnected));
assert_eq!(rx.recv().await, None);
}
}
}

@ -13,7 +13,8 @@ mod windows;
#[cfg(windows)]
pub use windows::*;
use crate::client::{Client, ReconnectStrategy, UntypedClient};
use super::ClientConfig;
use crate::client::{Client, UntypedClient};
use crate::common::{authentication::AuthHandler, Connection, Transport};
use async_trait::async_trait;
use std::{convert, io, time::Duration};
@ -40,44 +41,48 @@ impl<T: Transport + 'static> Connector for T {
pub struct ClientBuilder<H, C> {
auth_handler: H,
connector: C,
reconnect_strategy: ReconnectStrategy,
timeout: Option<Duration>,
config: ClientConfig,
connect_timeout: Option<Duration>,
}
impl<H, C> ClientBuilder<H, C> {
/// Configure the authentication handler to use when connecting to a server.
pub fn auth_handler<U>(self, auth_handler: U) -> ClientBuilder<U, C> {
ClientBuilder {
auth_handler,
config: self.config,
connector: self.connector,
reconnect_strategy: self.reconnect_strategy,
timeout: self.timeout,
connect_timeout: self.connect_timeout,
}
}
pub fn connector<U>(self, connector: U) -> ClientBuilder<H, U> {
ClientBuilder {
/// Configure the client-local configuration details.
pub fn config(self, config: ClientConfig) -> Self {
Self {
auth_handler: self.auth_handler,
connector,
reconnect_strategy: self.reconnect_strategy,
timeout: self.timeout,
config,
connector: self.connector,
connect_timeout: self.connect_timeout,
}
}
pub fn reconnect_strategy(self, reconnect_strategy: ReconnectStrategy) -> ClientBuilder<H, C> {
/// Configure the connector to use to facilitate connecting to a server.
pub fn connector<U>(self, connector: U) -> ClientBuilder<H, U> {
ClientBuilder {
auth_handler: self.auth_handler,
connector: self.connector,
reconnect_strategy,
timeout: self.timeout,
config: self.config,
connector,
connect_timeout: self.connect_timeout,
}
}
pub fn timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
/// Configure a maximum duration to wait for a connection to a server to complete.
pub fn connect_timeout(self, connect_timeout: impl Into<Option<Duration>>) -> Self {
Self {
auth_handler: self.auth_handler,
config: self.config,
connector: self.connector,
reconnect_strategy: self.reconnect_strategy,
timeout: timeout.into(),
connect_timeout: connect_timeout.into(),
}
}
}
@ -86,9 +91,9 @@ impl ClientBuilder<(), ()> {
pub fn new() -> Self {
Self {
auth_handler: (),
reconnect_strategy: ReconnectStrategy::default(),
config: Default::default(),
connector: (),
timeout: None,
connect_timeout: None,
}
}
}
@ -109,11 +114,11 @@ where
/// is fully established and authenticated.
pub async fn connect_untyped(self) -> io::Result<UntypedClient> {
let auth_handler = self.auth_handler;
let retry_strategy = self.reconnect_strategy;
let timeout = self.timeout;
let config = self.config;
let connect_timeout = self.connect_timeout;
let f = async move {
let transport = match timeout {
let transport = match connect_timeout {
Some(duration) => tokio::time::timeout(duration, self.connector.connect())
.await
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
@ -121,10 +126,10 @@ where
None => self.connector.connect().await?,
};
let connection = Connection::client(transport, auth_handler).await?;
Ok(UntypedClient::spawn(connection, retry_strategy))
Ok(UntypedClient::spawn(connection, config))
};
match timeout {
match connect_timeout {
Some(duration) => tokio::time::timeout(duration, f)
.await
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))

@ -0,0 +1,34 @@
use super::ReconnectStrategy;
use std::time::Duration;

/// Default maximum time to go without hearing from the server before a
/// reconnect is triggered.
const DEFAULT_SILENCE_DURATION: Duration = Duration::from_secs(20);

// NOTE(review): 68719476734 ms is roughly 2^36 - 2 ms (~795 days); presumably
// this is tied to the maximum duration the underlying timer supports — TODO
// confirm and document the source of this constant.
const MAXIMUM_SILENCE_DURATION: Duration = Duration::from_millis(68719476734);
/// Represents a general-purpose set of properties tied with a client instance.
#[derive(Clone, Debug)]
pub struct ClientConfig {
    /// Strategy to use when reconnecting to a server.
    pub reconnect_strategy: ReconnectStrategy,

    /// A maximum duration to not receive any response/heartbeat from a server before deeming the
    /// server as lost and triggering a reconnect.
    ///
    /// NOTE(review): for this to avoid spurious reconnects, the server's
    /// heartbeat interval should be shorter than this value — verify against
    /// the server-side heartbeat configuration.
    pub silence_duration: Duration,
}
impl ClientConfig {
    /// Returns this configuration with its silence duration raised to the
    /// maximum supported value, leaving every other setting untouched.
    pub fn with_maximum_silence_duration(self) -> Self {
        Self {
            silence_duration: MAXIMUM_SILENCE_DURATION,
            ..self
        }
    }
}
impl Default for ClientConfig {
    /// Creates a config that never reconnects ([`ReconnectStrategy::Fail`])
    /// and tolerates up to [`DEFAULT_SILENCE_DURATION`] of server silence.
    fn default() -> Self {
        Self {
            reconnect_strategy: ReconnectStrategy::Fail,
            silence_duration: DEFAULT_SILENCE_DURATION,
        }
    }
}

@ -1,4 +1,5 @@
use super::Reconnectable;
use log::*;
use std::io;
use std::time::Duration;
use strum::Display;
@ -170,8 +171,11 @@ impl ReconnectStrategy {
};
// If reconnect was successful, we're done and we can exit
if result.is_ok() {
return Ok(());
match &result {
Ok(()) => return Ok(()),
Err(x) => {
error!("Failed to reconnect: {x}");
}
}
// Decrement remaining retries if we have a limit

@ -91,8 +91,8 @@ where
///
/// For a client, this means performing an actual [`reconnect`] on the underlying
/// [`Transport`], re-establishing an encrypted codec, submitting a request to the server to
/// reauthenticate using a previously-derived OTP, and refreshing the connection id and OTP for
/// use in a future reauthentication.
/// reauthenticate using a previously-derived OTP, and refreshing the OTP for use in a future
/// reauthentication.
///
/// ### Server
///
@ -101,10 +101,10 @@ where
/// [`reconnect`]: Reconnectable::reconnect
async fn reconnect(&mut self) -> io::Result<()> {
async fn reconnect_client<T: Transport>(
id: &mut ConnectionId,
reauth_otp: &mut HeapSecretKey,
id: ConnectionId,
reauth_otp: HeapSecretKey,
transport: &mut FramedTransport<T>,
) -> io::Result<()> {
) -> io::Result<(ConnectionId, HeapSecretKey)> {
// Re-establish a raw connection
debug!("[Conn {id}] Re-establishing connection");
Reconnectable::reconnect(transport).await?;
@ -117,29 +117,16 @@ where
debug!("[Conn {id}] Performing re-authentication");
transport
.write_frame_for(&ConnectType::Reconnect {
id: *id,
otp: reauth_otp.unprotected_as_bytes().to_vec(),
id,
otp: reauth_otp.unprotected_into_bytes(),
})
.await?;
// Receive the new id for the connection
// NOTE: If we fail re-authentication above,
// this will fail as the connection is dropped
debug!("[Conn {id}] Receiving new connection id");
let new_id = transport
.read_frame_as::<ConnectionId>()
.await?
.ok_or_else(|| {
io::Error::new(io::ErrorKind::Other, "Missing connection id frame")
})?;
debug!("[Conn {id}] Resetting id to {new_id}");
*id = new_id;
// Derive an OTP for reauthentication
debug!("[Conn {id}] Deriving future OTP for reauthentication");
*reauth_otp = transport.exchange_keys().await?.into_heap_secret_key();
let reauth_otp = transport.exchange_keys().await?.into_heap_secret_key();
Ok(())
Ok((id, reauth_otp))
}
match self {
@ -148,19 +135,24 @@ where
transport,
reauth_otp,
} => {
// Freeze our backup as we don't want the connection logic to alter it
transport.backup.freeze();
// Attempt to perform the reconnection and unfreeze our backup regardless of the
// result
let result = reconnect_client(id, reauth_otp, transport).await;
transport.backup.unfreeze();
result?;
// Freeze our backup as we don't want the connection logic to alter it, attempt to
// perform the reconnection, and unfreeze our backup regardless of the result
let (new_id, new_reauth_otp) = {
transport.backup.freeze();
let result = reconnect_client(*id, reauth_otp.clone(), transport).await;
transport.backup.unfreeze();
result?
};
// Perform synchronization
debug!("[Conn {id}] Synchronizing frame state");
transport.synchronize().await?;
// Everything has succeeded, so we now will update our id and reauth otp
info!("[Conn {id}] Reconnect completed successfully! Assigning new id {new_id}");
*id = new_id;
*reauth_otp = new_reauth_otp;
Ok(())
}
@ -234,6 +226,7 @@ where
debug!("[Conn {id}] Deriving future OTP for reauthentication");
let reauth_otp = transport.exchange_keys().await?.into_heap_secret_key();
info!("[Conn {id}] Connect completed successfully!");
Ok(Self::Client {
id,
reauth_otp,
@ -283,7 +276,7 @@ where
// Based on the connection type, we either try to find and validate an existing connection
// or we perform normal verification
match connection_type {
let id = match connection_type {
ConnectType::Connect => {
// Communicate the connection id
debug!("[Conn {id}] Telling other side to change connection id");
@ -298,42 +291,69 @@ where
let reauth_otp = transport.exchange_keys().await?.into_heap_secret_key();
// Store the id, OTP, and backup retrieval in our database
info!("[Conn {id}] Connect completed successfully!");
keychain.insert(id.to_string(), reauth_otp, rx).await;
id
}
ConnectType::Reconnect { id: other_id, otp } => {
let reauth_otp = HeapSecretKey::from(otp);
debug!("[Conn {id}] Checking if {other_id} exists and has matching OTP");
match keychain
.remove_if_has_key(other_id.to_string(), reauth_otp)
.remove_if_has_key(other_id.to_string(), reauth_otp.clone())
.await
{
KeychainResult::Ok(x) => {
// Communicate the connection id
debug!("[Conn {id}] Telling other side to change connection id");
transport.write_frame_for(&id).await?;
// Derive an OTP for reauthentication
debug!("[Conn {id}] Deriving future OTP for reauthentication");
let reauth_otp = transport.exchange_keys().await?.into_heap_secret_key();
// Match found, so we want ot update our id to be the pre-existing id
debug!("[Conn {id}] Reassigning to {other_id}");
let id = other_id;
// Grab the old backup and swap it into our transport
// Grab the old backup
debug!("[Conn {id}] Acquiring backup for existing connection");
match x.await {
Ok(backup) => {
transport.backup = backup;
}
let backup = match x.await {
Ok(backup) => backup,
Err(_) => {
warn!("[Conn {id}] Missing backup");
warn!("[Conn {id}] Missing backup, will use fresh copy");
Backup::new()
}
};
macro_rules! unwrap_or_fail {
($action:expr) => {
unwrap_or_fail!(backup, $action)
};
($backup:expr, $action:expr) => {{
match $action {
Ok(x) => x,
Err(x) => {
error!("[Conn {id}] Encountered error, restoring with old backup");
let _ = tx.send($backup);
keychain.insert(id.to_string(), reauth_otp, rx).await;
return Err(x);
}
}
}};
}
// Derive an OTP for reauthentication
debug!("[Conn {id}] Deriving future OTP for reauthentication");
let new_reauth_otp =
unwrap_or_fail!(transport.exchange_keys().await).into_heap_secret_key();
// Replace our backup with the old one
debug!("[Conn {id}] Restoring backup");
transport.backup = backup;
// Synchronize using the provided backup
debug!("[Conn {id}] Synchronizing frame state");
transport.synchronize().await?;
unwrap_or_fail!(transport.backup, transport.synchronize().await);
// Store the id, OTP, and backup retrieval in our database
keychain.insert(id.to_string(), reauth_otp, rx).await;
info!("[Conn {id}] Reconnect restoration completed successfully!");
keychain.insert(id.to_string(), new_reauth_otp, rx).await;
id
}
KeychainResult::InvalidPassword => {
return Err(io::Error::new(
@ -349,7 +369,7 @@ where
}
}
}
}
};
Ok(Self::Server { id, tx, transport })
}
@ -384,7 +404,6 @@ impl Connection<InmemoryTransport> {
}
}
#[cfg(test)]
impl<T> Connection<T> {
/// Returns the id of the connection.
pub fn id(&self) -> ConnectionId {
@ -393,7 +412,10 @@ impl<T> Connection<T> {
Self::Server { id, .. } => *id,
}
}
}
#[cfg(test)]
impl<T> Connection<T> {
/// Returns the OTP associated with the connection, or none if connection is server-side.
pub fn otp(&self) -> Option<&HeapSecretKey> {
match self {
@ -821,9 +843,6 @@ mod tests {
.await
.unwrap();
// Receive a new client id
let _id = t1.read_frame_as::<ConnectionId>().await.unwrap().unwrap();
// Send garbage to fail the otp exchange
t1.write_frame(Frame::new(b"hello")).await.unwrap();
@ -862,9 +881,6 @@ mod tests {
.await
.unwrap();
// Receive a new client id
let _id = t1.read_frame_as::<ConnectionId>().await.unwrap().unwrap();
// Perform otp exchange
let _otp = t1.exchange_keys().await.unwrap();
@ -928,9 +944,10 @@ mod tests {
let verifier = Verifier::none();
let keychain = Keychain::new();
let key = HeapSecretKey::generate(32).unwrap();
let id = 1234;
keychain
.insert(1234.to_string(), key.clone(), {
.insert(id.to_string(), key.clone(), {
// Create a custom backup we'll use to replay frames from the server-side
let mut backup = Backup::new();
@ -968,9 +985,6 @@ mod tests {
.await
.unwrap();
// Receive a new client id
let id = t1.read_frame_as::<ConnectionId>().await.unwrap().unwrap();
// Perform otp exchange
let otp = t1.exchange_keys().await.unwrap();
@ -996,9 +1010,6 @@ mod tests {
// Validate the connection ids match
assert_eq!(server.id(), id);
// Check that our old connection id is no longer contained in the keychain
assert!(!keychain.has_id("1234").await, "Old OTP still exists");
// Validate the OTP was stored in our keychain
assert!(
keychain
@ -1210,12 +1221,6 @@ mod tests {
.await
.expect("Failed to retrieve backup");
// Send a new id back to the client connection
transport
.write_frame_for(&rand::random::<ConnectionId>())
.await
.unwrap();
// Perform key exchange
let otp = transport.exchange_keys().await.unwrap();

@ -1,5 +1,5 @@
use async_trait::async_trait;
use std::{io, time::Duration};
use std::{fmt, io, time::Duration};
mod framed;
pub use framed::*;
@ -42,7 +42,7 @@ pub trait Reconnectable {
/// Interface representing a transport of raw bytes into and out of the system.
#[async_trait]
pub trait Transport: Reconnectable + Send + Sync {
pub trait Transport: Reconnectable + fmt::Debug + Send + Sync {
/// Tries to read data from the transport into the provided buffer, returning how many bytes
/// were read.
///

@ -254,12 +254,13 @@ impl<T: Transport> FramedTransport<T> {
macro_rules! read_next_frame {
() => {{
match Frame::read(&mut self.incoming) {
Ok(None) => (),
Ok(Some(frame)) => {
self.backup.increment_received_cnt();
None => (),
Some(frame) => {
if frame.is_nonempty() {
self.backup.increment_received_cnt();
}
return Ok(Some(self.codec.decode(frame)?.into_owned()));
}
Err(x) => return Err(x),
}
}};
}
@ -363,14 +364,17 @@ impl<T: Transport> FramedTransport<T> {
// Encode the frame and store it in our outgoing queue
self.codec
.encode(frame.as_borrowed())?
.write(&mut self.outgoing)?;
.write(&mut self.outgoing);
// Once the frame enters our queue, we count it as written, even if it isn't fully flushed
self.backup.increment_sent_cnt();
// Update tracking stats and more of backup if frame is nonempty
if frame.is_nonempty() {
// Once the frame enters our queue, we count it as written, even if it isn't fully flushed
self.backup.increment_sent_cnt();
// Then we store the raw frame (non-encoded) for the future in case we need to retry
// sending it later (possibly with a different codec)
self.backup.push_frame(frame);
// Then we store the raw frame (non-encoded) for the future in case we need to retry
// sending it later (possibly with a different codec)
self.backup.push_frame(frame);
}
// Attempt to write everything in our queue
self.try_flush()?;
@ -535,7 +539,7 @@ impl<T: Transport> FramedTransport<T> {
// Encode our frame and write it to be queued in our incoming data
// NOTE: We have to do encoding here as incoming bytes are expected to be encoded
this.codec.encode(frame)?.write(&mut this.incoming)?;
this.codec.encode(frame)?.write(&mut this.incoming);
}
// Catch up our read count as we can have the case where the other side has a higher
@ -701,7 +705,7 @@ impl<T: Transport> FramedTransport<T> {
})?;
// Choose a compression and encryption option from the options
debug!("[{log_label}] Selecting from options: {options:#?}");
debug!("[{log_label}] Selecting from options: {options:?}");
let choice = Choice {
// Use preferred compression if available, otherwise default to no compression
// to avoid choosing something poor
@ -725,7 +729,7 @@ impl<T: Transport> FramedTransport<T> {
};
// Report back to the server the choice
debug!("[{log_label}] Reporting choice: {choice:#?}");
debug!("[{log_label}] Reporting choice: {choice:?}");
self.write_frame_for(&choice).await?;
choice
@ -740,7 +744,7 @@ impl<T: Transport> FramedTransport<T> {
};
// Send options to the client
debug!("[{log_label}] Sending options: {options:#?}");
debug!("[{log_label}] Sending options: {options:?}");
self.write_frame_for(&options).await?;
// Get client's response with selected compression and encryption
@ -754,7 +758,7 @@ impl<T: Transport> FramedTransport<T> {
}
};
debug!("[{log_label}] Building compression & encryption codecs based on {choice:#?}");
debug!("[{log_label}] Building compression & encryption codecs based on {choice:?}");
let compression_level = choice.compression_level.unwrap_or_default();
// Acquire a codec for the compression type
@ -968,7 +972,7 @@ mod tests {
let mut buf = BytesMut::new();
for frame in frames {
frame.write(&mut buf).unwrap();
frame.write(&mut buf);
}
buf.to_vec()
@ -1059,7 +1063,7 @@ mod tests {
fn try_read_frame_should_return_next_available_frame() {
let data = {
let mut data = BytesMut::new();
Frame::new(b"hello world").write(&mut data).unwrap();
Frame::new(b"hello world").write(&mut data);
data.freeze()
};
@ -1082,8 +1086,8 @@ mod tests {
// Store two frames in our data to transmit
let data = {
let mut data = BytesMut::new();
Frame::new(b"hello world").write(&mut data).unwrap();
Frame::new(b"hello again").write(&mut data).unwrap();
Frame::new(b"hello world").write(&mut data);
Frame::new(b"hello again").write(&mut data);
data.freeze()
};
@ -1746,8 +1750,8 @@ mod tests {
let (mut t1, mut t2) = FramedTransport::pair(100);
// Put some frames into the incoming and outgoing of our transport
Frame::new(b"bad incoming").write(&mut t2.incoming).unwrap();
Frame::new(b"bad outgoing").write(&mut t2.outgoing).unwrap();
Frame::new(b"bad incoming").write(&mut t2.incoming);
Frame::new(b"bad outgoing").write(&mut t2.outgoing);
// Configure the backup such that we have sent two frames
t2.backup.push_frame(Frame::new(b"hello"));

@ -5,6 +5,11 @@ use std::collections::VecDeque;
const MAX_BACKUP_SIZE: usize = 256 * 1024 * 1024;
/// Stores [`Frame`]s for reuse later.
///
/// ### Note
///
/// Empty [`Frame`]s are an exception and are not stored within the backup nor
/// are they tracked in terms of sent/received counts.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Backup {
/// Maximum size (in bytes) to save frames in case we need to backup them

@ -1,5 +1,5 @@
use bytes::{Buf, BufMut, BytesMut};
use std::{borrow::Cow, io};
use std::borrow::Cow;
/// Represents a frame whose lifetime is static
pub type OwnedFrame = Frame<'static>;
@ -13,7 +13,7 @@ pub struct Frame<'a> {
}
impl<'a> Frame<'a> {
/// Creates a new frame wrapping the `item` that will be shipped across the network
/// Creates a new frame wrapping the `item` that will be shipped across the network.
pub fn new(item: &'a [u8]) -> Self {
Self {
item: Cow::Borrowed(item),
@ -27,75 +27,66 @@ impl<'a> Frame<'a> {
}
impl Frame<'_> {
/// Total bytes to use as the header field denoting a frame's size
/// Total bytes to use as the header field denoting a frame's size.
pub const HEADER_SIZE: usize = 8;
/// Returns the len (in bytes) of the item wrapped by the frame
/// Creates a new frame with no item.
pub fn empty() -> Self {
Self::new(&[])
}
/// Returns the len (in bytes) of the item wrapped by the frame.
pub fn len(&self) -> usize {
self.item.len()
}
/// Returns true if the frame is comprised of zero bytes
/// Returns true if the frame is comprised of zero bytes.
pub fn is_empty(&self) -> bool {
self.item.is_empty()
}
/// Returns a reference to the bytes of the frame's item
/// Returns true if the frame is comprised of some bytes.
#[inline]
pub fn is_nonempty(&self) -> bool {
!self.is_empty()
}
/// Returns a reference to the bytes of the frame's item.
pub fn as_item(&self) -> &[u8] {
&self.item
}
/// Writes the frame to a new [`Vec`] of bytes, returning them on success
pub fn try_to_bytes(&self) -> io::Result<Vec<u8>> {
/// Writes the frame to a new [`Vec`] of bytes, returning them on success.
pub fn to_bytes(&self) -> Vec<u8> {
let mut bytes = BytesMut::new();
self.write(&mut bytes)?;
Ok(bytes.to_vec())
self.write(&mut bytes);
bytes.to_vec()
}
/// Writes the frame to the end of `dst`, including the header representing the length of the
/// item as part of the written bytes
pub fn write(&self, dst: &mut BytesMut) -> io::Result<()> {
if self.item.is_empty() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Empty item provided",
));
}
/// item as part of the written bytes.
pub fn write(&self, dst: &mut BytesMut) {
dst.reserve(Self::HEADER_SIZE + self.item.len());
// Add data in form of {LEN}{ITEM}
dst.put_u64((self.item.len()) as u64);
dst.put_slice(&self.item);
Ok(())
}
/// Attempts to read a frame from `src`, returning `Some(Frame)` if a frame was found
/// (including the header) or `None` if the current `src` does not contain a frame
pub fn read(src: &mut BytesMut) -> io::Result<Option<OwnedFrame>> {
/// (including the header) or `None` if the current `src` does not contain a frame.
pub fn read(src: &mut BytesMut) -> Option<OwnedFrame> {
// First, check if we have more data than just our frame's message length
if src.len() <= Self::HEADER_SIZE {
return Ok(None);
return None;
}
// Second, retrieve total size of our frame's message
let item_len = u64::from_be_bytes(src[..Self::HEADER_SIZE].try_into().unwrap()) as usize;
// In the case that our item len is 0, we skip over the invalid frame
if item_len == 0 {
// Ensure we advance to remove the frame
src.advance(Self::HEADER_SIZE);
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Frame's msg cannot have length of 0",
));
}
// Third, check if we have all data for our frame; if not, exit early
if src.len() < item_len + Self::HEADER_SIZE {
return Ok(None);
return None;
}
// Fourth, get and return our item
@ -104,13 +95,13 @@ impl Frame<'_> {
// Fifth, advance so frame is no longer kept around
src.advance(Self::HEADER_SIZE + item_len);
Ok(Some(Frame::from(item)))
Some(Frame::from(item))
}
/// Checks if a full frame is available from `src`, returning true if a frame was found false
/// if the current `src` does not contain a frame. Does not consume the frame.
pub fn available(src: &BytesMut) -> bool {
matches!(Frame::read(&mut src.clone()), Ok(Some(_)))
matches!(Frame::read(&mut src.clone()), Some(_))
}
/// Returns a new frame which is identical but has a lifetime tied to this frame.
@ -239,16 +230,15 @@ mod tests {
use test_log::test;
#[test]
fn write_should_fail_when_item_is_zero_bytes() {
fn write_should_succeed_when_item_is_zero_bytes() {
let frame = Frame::new(&[]);
let mut buf = BytesMut::new();
let result = frame.write(&mut buf);
frame.write(&mut buf);
match result {
Err(x) if x.kind() == io::ErrorKind::InvalidInput => {}
x => panic!("Unexpected result: {:?}", x),
}
// Writing a frame of zero bytes means that the header is all zeros and there is
// no item that follows the header
assert_eq!(buf.as_ref(), &[0, 0, 0, 0, 0, 0, 0, 0]);
}
#[test]
@ -256,7 +246,7 @@ mod tests {
let frame = Frame::new(b"hello, world");
let mut buf = BytesMut::new();
frame.write(&mut buf).expect("Failed to write");
frame.write(&mut buf);
let len = buf.get_u64() as usize;
assert_eq!(len, 12, "Wrong length writed");
@ -269,11 +259,7 @@ mod tests {
buf.put_bytes(0, Frame::HEADER_SIZE);
let result = Frame::read(&mut buf);
assert!(
matches!(result, Ok(None)),
"Unexpected result: {:?}",
result
);
assert!(matches!(result, None), "Unexpected result: {:?}", result);
}
#[test]
@ -282,24 +268,21 @@ mod tests {
buf.put_u64(0);
let result = Frame::read(&mut buf);
assert!(
matches!(result, Ok(None)),
"Unexpected result: {:?}",
result
);
assert!(matches!(result, None), "Unexpected result: {:?}", result);
}
#[test]
fn read_should_fail_if_writed_item_length_is_zero() {
fn read_should_succeed_if_written_item_length_is_zero() {
let mut buf = BytesMut::new();
buf.put_u64(0);
buf.put_u8(255);
let result = Frame::read(&mut buf);
match result {
Err(x) if x.kind() == io::ErrorKind::InvalidData => {}
x => panic!("Unexpected result: {:?}", x),
}
// Reading will result in a frame of zero bytes
let frame = Frame::read(&mut buf).expect("missing frame");
assert_eq!(frame, Frame::empty());
// Nothing following the frame header should have been extracted
assert_eq!(buf.as_ref(), &[255]);
}
#[test]
@ -308,10 +291,7 @@ mod tests {
buf.put_u64(0);
buf.put_bytes(0, 3);
assert!(
Frame::read(&mut buf).is_err(),
"read unexpectedly succeeded"
);
assert_eq!(Frame::read(&mut buf).unwrap(), Frame::empty());
assert_eq!(buf.len(), 3, "Advanced an unexpected amount in src buf");
}
@ -319,25 +299,22 @@ mod tests {
fn read_should_advance_src_by_frame_size_when_successful() {
// Add 3 extra bytes after a full frame
let mut buf = BytesMut::new();
Frame::new(b"hello, world")
.write(&mut buf)
.expect("Failed to write");
Frame::new(b"hello, world").write(&mut buf);
buf.put_bytes(0, 3);
assert!(Frame::read(&mut buf).is_ok(), "read unexpectedly failed");
assert!(
Frame::read(&mut buf).is_some(),
"read unexpectedly missing frame"
);
assert_eq!(buf.len(), 3, "Advanced an unexpected amount in src buf");
}
#[test]
fn read_should_return_some_byte_vec_when_successful() {
let mut buf = BytesMut::new();
Frame::new(b"hello, world")
.write(&mut buf)
.expect("Failed to write");
Frame::new(b"hello, world").write(&mut buf);
let item = Frame::read(&mut buf)
.expect("Failed to read")
.expect("Item not properly captured");
let item = Frame::read(&mut buf).expect("missing frame");
assert_eq!(item, b"hello, world");
}
}

@ -1,6 +1,6 @@
use super::{Interest, Ready, Reconnectable, Transport};
use async_trait::async_trait;
use std::io;
use std::{fmt, io};
pub type TryReadFn = Box<dyn Fn(&mut [u8]) -> io::Result<usize> + Send + Sync>;
pub type TryWriteFn = Box<dyn Fn(&[u8]) -> io::Result<usize> + Send + Sync>;
@ -25,6 +25,12 @@ impl Default for TestTransport {
}
}
// Manual impl: the boxed `Fn` fields (e.g. `TryReadFn`, `TryWriteFn`) have no
// `Debug` impl, so we render just the struct name with no fields.
impl fmt::Debug for TestTransport {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("TestTransport").finish()
    }
}
#[async_trait]
impl Reconnectable for TestTransport {
async fn reconnect(&mut self) -> io::Result<()> {

@ -303,13 +303,13 @@ impl ManagerClient {
#[cfg(test)]
mod tests {
use super::*;
use crate::client::{ReconnectStrategy, UntypedClient};
use crate::client::UntypedClient;
use crate::common::authentication::DummyAuthHandler;
use crate::common::{Connection, InmemoryTransport, Request, Response};
fn setup() -> (ManagerClient, Connection<InmemoryTransport>) {
let (client, server) = Connection::pair(100);
let client = UntypedClient::spawn(client, ReconnectStrategy::Fail).into_typed_client();
let client = UntypedClient::spawn(client, Default::default()).into_typed_client();
(client, server)
}

@ -1,5 +1,5 @@
use crate::{
client::{Client, ReconnectStrategy, UntypedClient},
client::{Client, ClientConfig, UntypedClient},
common::{ConnectionId, FramedTransport, InmemoryTransport, UntypedRequest},
manager::data::{ManagerRequest, ManagerResponse},
};
@ -35,7 +35,10 @@ impl RawChannel {
T: Send + Sync + Serialize + 'static,
U: Send + Sync + DeserializeOwned + 'static,
{
Client::spawn_inmemory(self.transport, ReconnectStrategy::Fail)
Client::spawn_inmemory(
self.transport,
ClientConfig::default().with_maximum_silence_duration(),
)
}
/// Consumes this channel, returning an untyped client wrapping the transport.
@ -46,7 +49,10 @@ impl RawChannel {
/// performed during separate connection and this merely wraps an inmemory transport that maps
/// to the primary connection.
pub fn into_untyped_client(self) -> UntypedClient {
UntypedClient::spawn_inmemory(self.transport, ReconnectStrategy::Fail)
UntypedClient::spawn_inmemory(
self.transport,
ClientConfig::default().with_maximum_silence_duration(),
)
}
/// Returns reference to the underlying framed transport.

@ -316,7 +316,7 @@ impl ServerHandler for ManagerServer {
#[cfg(test)]
mod tests {
use super::*;
use crate::client::{ReconnectStrategy, UntypedClient};
use crate::client::UntypedClient;
use crate::common::FramedTransport;
use crate::server::ServerReply;
use crate::{boxed_connect_handler, boxed_launch_handler};
@ -335,7 +335,7 @@ mod tests {
/// Create an untyped client that is detached such that reads and writes will fail
fn detached_untyped_client() -> UntypedClient {
UntypedClient::spawn_inmemory(FramedTransport::pair(1).0, ReconnectStrategy::Fail)
UntypedClient::spawn_inmemory(FramedTransport::pair(1).0, Default::default())
}
/// Create a new server and authenticator

@ -1,9 +1,9 @@
use crate::common::{authentication::Verifier, Listener, Transport};
use crate::common::{authentication::Verifier, Listener, Response, Transport};
use async_trait::async_trait;
use log::*;
use serde::{de::DeserializeOwned, Serialize};
use std::{io, sync::Arc, time::Duration};
use tokio::sync::RwLock;
use tokio::sync::{broadcast, RwLock};
mod builder;
pub use builder::*;
@ -148,14 +148,20 @@ where
L::Output: Transport + 'static,
{
let state = Arc::new(ServerState::new());
let task = tokio::spawn(self.task(Arc::clone(&state), listener));
let (tx, rx) = broadcast::channel(1);
let task = tokio::spawn(self.task(Arc::clone(&state), listener, tx.clone(), rx));
Ok(Box::new(GenericServerRef { state, task }))
Ok(Box::new(GenericServerRef { shutdown: tx, task }))
}
/// Internal task that is run to receive connections and spawn connection tasks
async fn task<L>(self, state: Arc<ServerState>, mut listener: L)
where
async fn task<L>(
self,
state: Arc<ServerState<Response<T::Response>>>,
mut listener: L,
shutdown_tx: broadcast::Sender<()>,
shutdown_rx: broadcast::Receiver<()>,
) where
L: Listener + 'static,
L::Output: Transport + 'static,
{
@ -171,6 +177,7 @@ where
let timer = Arc::new(RwLock::new(timer));
let verifier = Arc::new(verifier);
let mut connection_tasks = Vec::new();
loop {
// Receive a new connection, exiting if no longer accepting connections or if the shutdown
// signal has been received
@ -191,10 +198,7 @@ where
config.shutdown.duration().unwrap_or_default().as_secs_f32(),
);
for (id, task) in state.connections.write().await.drain() {
info!("Terminating task {id}");
task.abort();
}
let _ = shutdown_tx.send(());
break;
}
@ -203,26 +207,28 @@ where
// Ensure that the shutdown timer is cancelled now that we have a connection
timer.read().await.stop();
let connection = ConnectionTask::build()
.handler(Arc::downgrade(&handler))
.state(Arc::downgrade(&state))
.keychain(state.keychain.clone())
.transport(transport)
.shutdown_timer(Arc::downgrade(&timer))
.sleep_duration(config.connection_sleep)
.verifier(Arc::downgrade(&verifier))
.spawn();
state
.connections
.write()
.await
.insert(connection.id(), connection);
connection_tasks.push(
ConnectionTask::build()
.handler(Arc::downgrade(&handler))
.state(Arc::downgrade(&state))
.keychain(state.keychain.clone())
.transport(transport)
.shutdown(shutdown_rx.resubscribe())
.shutdown_timer(Arc::downgrade(&timer))
.sleep_duration(config.connection_sleep)
.heartbeat_duration(config.connection_heartbeat)
.verifier(Arc::downgrade(&verifier))
.spawn(),
);
}
// Once we stop listening, we still want to wait until all connections have terminated
info!("Server waiting for active connections to terminate");
while state.has_active_connections().await {
loop {
connection_tasks.retain(|task| !task.is_finished());
if connection_tasks.is_empty() {
break;
}
tokio::time::sleep(Duration::from_millis(50)).await;
}
info!("Server task terminated");

@ -3,6 +3,7 @@ use serde::{Deserialize, Serialize};
use std::{num::ParseFloatError, str::FromStr, time::Duration};
const DEFAULT_CONNECTION_SLEEP: Duration = Duration::from_millis(1);
const DEFAULT_HEARTBEAT_DURATION: Duration = Duration::from_secs(5);
/// Represents a general-purpose set of properties tied with a server instance
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
@ -10,6 +11,9 @@ pub struct ServerConfig {
/// Time to wait inbetween connection read/write when nothing was read or written on last pass
pub connection_sleep: Duration,
/// Minimum time to wait inbetween sending heartbeat messages
pub connection_heartbeat: Duration,
/// Rules for how a server will shutdown automatically
pub shutdown: Shutdown,
}
@ -18,6 +22,7 @@ impl Default for ServerConfig {
fn default() -> Self {
Self {
connection_sleep: DEFAULT_CONNECTION_SLEEP,
connection_heartbeat: DEFAULT_HEARTBEAT_DURATION,
shutdown: Default::default(),
}
}

@ -1,7 +1,10 @@
use super::{ConnectionCtx, ServerCtx, ServerHandler, ServerReply, ServerState, ShutdownTimer};
use super::{
ConnectionCtx, ConnectionState, ServerCtx, ServerHandler, ServerReply, ServerState,
ShutdownTimer,
};
use crate::common::{
authentication::{Keychain, Verifier},
Backup, Connection, ConnectionId, Interest, Response, Transport, UntypedRequest,
Backup, Connection, Frame, Interest, Response, Transport, UntypedRequest,
};
use log::*;
use serde::{de::DeserializeOwned, Serialize};
@ -11,56 +14,33 @@ use std::{
pin::Pin,
sync::{Arc, Weak},
task::{Context, Poll},
time::Duration,
time::{Duration, Instant},
};
use tokio::{
sync::{mpsc, oneshot, RwLock},
sync::{broadcast, mpsc, oneshot, RwLock},
task::JoinHandle,
};
pub type ServerKeychain = Keychain<oneshot::Receiver<Backup>>;
/// Time to wait inbetween connection read/write when nothing was read or written on last pass
/// Time to wait inbetween connection read/write when nothing was read or written on last pass.
const SLEEP_DURATION: Duration = Duration::from_millis(1);
/// Represents an individual connection on the server
pub struct ConnectionTask {
/// Unique identifier tied to the connection
id: ConnectionId,
/// Minimum time between heartbeats to communicate to the client connection.
const MINIMUM_HEARTBEAT_DURATION: Duration = Duration::from_secs(5);
/// Task that is processing requests and responses
task: JoinHandle<io::Result<()>>,
}
/// Represents an individual connection on the server.
pub(super) struct ConnectionTask(JoinHandle<io::Result<()>>);
impl ConnectionTask {
/// Starts building a new connection
pub fn build() -> ConnectionTaskBuilder<(), ()> {
let id: ConnectionId = rand::random();
ConnectionTaskBuilder {
id,
handler: Weak::new(),
state: Weak::new(),
keychain: Keychain::new(),
transport: (),
shutdown_timer: Weak::new(),
sleep_duration: SLEEP_DURATION,
verifier: Weak::new(),
}
}
/// Returns the id associated with the connection
pub fn id(&self) -> ConnectionId {
self.id
pub fn build() -> ConnectionTaskBuilder<(), (), ()> {
ConnectionTaskBuilder::new()
}
/// Returns true if the task has finished
pub fn is_finished(&self) -> bool {
self.task.is_finished()
}
/// Aborts the connection
pub fn abort(&self) {
self.task.abort();
self.0.is_finished()
}
}
@ -68,7 +48,7 @@ impl Future for ConnectionTask {
type Output = io::Result<()>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match Future::poll(Pin::new(&mut self.task), cx) {
match Future::poll(Pin::new(&mut self.0), cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(x) => match x {
Ok(x) => Poll::Ready(x),
@ -78,114 +58,171 @@ impl Future for ConnectionTask {
}
}
pub struct ConnectionTaskBuilder<H, T> {
id: ConnectionId,
/// Represents a builder for a new connection task.
pub(super) struct ConnectionTaskBuilder<H, S, T> {
handler: Weak<H>,
state: Weak<ServerState>,
state: Weak<ServerState<S>>,
keychain: Keychain<oneshot::Receiver<Backup>>,
transport: T,
shutdown: broadcast::Receiver<()>,
shutdown_timer: Weak<RwLock<ShutdownTimer>>,
sleep_duration: Duration,
heartbeat_duration: Duration,
verifier: Weak<Verifier>,
}
impl<H, T> ConnectionTaskBuilder<H, T> {
pub fn handler<U>(self, handler: Weak<U>) -> ConnectionTaskBuilder<U, T> {
impl ConnectionTaskBuilder<(), (), ()> {
/// Starts building a new connection.
pub fn new() -> Self {
Self {
handler: Weak::new(),
state: Weak::new(),
keychain: Keychain::new(),
transport: (),
shutdown: broadcast::channel(1).1,
shutdown_timer: Weak::new(),
sleep_duration: SLEEP_DURATION,
heartbeat_duration: MINIMUM_HEARTBEAT_DURATION,
verifier: Weak::new(),
}
}
}
impl<H, S, T> ConnectionTaskBuilder<H, S, T> {
pub fn handler<U>(self, handler: Weak<U>) -> ConnectionTaskBuilder<U, S, T> {
ConnectionTaskBuilder {
id: self.id,
handler,
state: self.state,
keychain: self.keychain,
transport: self.transport,
shutdown: self.shutdown,
shutdown_timer: self.shutdown_timer,
sleep_duration: self.sleep_duration,
heartbeat_duration: self.heartbeat_duration,
verifier: self.verifier,
}
}
pub fn state(self, state: Weak<ServerState>) -> ConnectionTaskBuilder<H, T> {
pub fn state<U>(self, state: Weak<ServerState<U>>) -> ConnectionTaskBuilder<H, U, T> {
ConnectionTaskBuilder {
id: self.id,
handler: self.handler,
state,
keychain: self.keychain,
transport: self.transport,
shutdown: self.shutdown,
shutdown_timer: self.shutdown_timer,
sleep_duration: self.sleep_duration,
heartbeat_duration: self.heartbeat_duration,
verifier: self.verifier,
}
}
pub fn keychain(self, keychain: ServerKeychain) -> ConnectionTaskBuilder<H, T> {
pub fn keychain(self, keychain: ServerKeychain) -> ConnectionTaskBuilder<H, S, T> {
ConnectionTaskBuilder {
id: self.id,
handler: self.handler,
state: self.state,
keychain,
transport: self.transport,
shutdown: self.shutdown,
shutdown_timer: self.shutdown_timer,
sleep_duration: self.sleep_duration,
heartbeat_duration: self.heartbeat_duration,
verifier: self.verifier,
}
}
pub fn transport<U>(self, transport: U) -> ConnectionTaskBuilder<H, U> {
pub fn transport<U>(self, transport: U) -> ConnectionTaskBuilder<H, S, U> {
ConnectionTaskBuilder {
id: self.id,
handler: self.handler,
keychain: self.keychain,
state: self.state,
transport,
shutdown: self.shutdown,
shutdown_timer: self.shutdown_timer,
sleep_duration: self.sleep_duration,
heartbeat_duration: self.heartbeat_duration,
verifier: self.verifier,
}
}
pub fn shutdown(self, shutdown: broadcast::Receiver<()>) -> ConnectionTaskBuilder<H, S, T> {
ConnectionTaskBuilder {
handler: self.handler,
state: self.state,
keychain: self.keychain,
transport: self.transport,
shutdown,
shutdown_timer: self.shutdown_timer,
sleep_duration: self.sleep_duration,
heartbeat_duration: self.heartbeat_duration,
verifier: self.verifier,
}
}
pub(crate) fn shutdown_timer(
pub fn shutdown_timer(
self,
shutdown_timer: Weak<RwLock<ShutdownTimer>>,
) -> ConnectionTaskBuilder<H, T> {
) -> ConnectionTaskBuilder<H, S, T> {
ConnectionTaskBuilder {
id: self.id,
handler: self.handler,
state: self.state,
keychain: self.keychain,
transport: self.transport,
shutdown: self.shutdown,
shutdown_timer,
sleep_duration: self.sleep_duration,
heartbeat_duration: self.heartbeat_duration,
verifier: self.verifier,
}
}
pub fn sleep_duration(self, sleep_duration: Duration) -> ConnectionTaskBuilder<H, T> {
pub fn sleep_duration(self, sleep_duration: Duration) -> ConnectionTaskBuilder<H, S, T> {
ConnectionTaskBuilder {
id: self.id,
handler: self.handler,
state: self.state,
keychain: self.keychain,
transport: self.transport,
shutdown: self.shutdown,
shutdown_timer: self.shutdown_timer,
sleep_duration,
heartbeat_duration: self.heartbeat_duration,
verifier: self.verifier,
}
}
pub fn heartbeat_duration(
self,
heartbeat_duration: Duration,
) -> ConnectionTaskBuilder<H, S, T> {
ConnectionTaskBuilder {
handler: self.handler,
state: self.state,
keychain: self.keychain,
transport: self.transport,
shutdown: self.shutdown,
shutdown_timer: self.shutdown_timer,
sleep_duration: self.sleep_duration,
heartbeat_duration,
verifier: self.verifier,
}
}
pub fn verifier(self, verifier: Weak<Verifier>) -> ConnectionTaskBuilder<H, T> {
pub fn verifier(self, verifier: Weak<Verifier>) -> ConnectionTaskBuilder<H, S, T> {
ConnectionTaskBuilder {
id: self.id,
handler: self.handler,
state: self.state,
keychain: self.keychain,
transport: self.transport,
shutdown: self.shutdown,
shutdown_timer: self.shutdown_timer,
sleep_duration: self.sleep_duration,
heartbeat_duration: self.heartbeat_duration,
verifier,
}
}
}
impl<H, T> ConnectionTaskBuilder<H, T>
impl<H, T> ConnectionTaskBuilder<H, Response<H::Response>, T>
where
H: ServerHandler + Sync + 'static,
H::Request: DeserializeOwned + Send + Sync + 'static,
@ -194,52 +231,86 @@ where
T: Transport + 'static,
{
pub fn spawn(self) -> ConnectionTask {
let id = self.id;
ConnectionTask {
id,
task: tokio::spawn(self.run()),
}
ConnectionTask(tokio::spawn(self.run()))
}
async fn run(self) -> io::Result<()> {
let ConnectionTaskBuilder {
id,
handler,
state,
keychain,
transport,
mut shutdown,
shutdown_timer,
sleep_duration,
heartbeat_duration,
verifier,
} = self;
// NOTE: This exists purely to make the compiler happy for macro_rules declaration order.
let (mut local_shutdown, channel_tx, connection_state) = ConnectionState::channel();
// Will check if no more connections and restart timer if that's the case
macro_rules! terminate_connection {
// Prints an error message before terminating the connection by panicking
(@error $($msg:tt)+) => {
// Prints an error message and does not store state
(@fatal $($msg:tt)+) => {
error!($($msg)+);
terminate_connection!();
return Err(io::Error::new(io::ErrorKind::Other, format!($($msg)+)));
};
// Prints a debug message before terminating the connection by cleanly returning
(@debug $($msg:tt)+) => {
// Prints an error message and stores state before terminating
(@error($tx:ident, $rx:ident) $($msg:tt)+) => {
error!($($msg)+);
terminate_connection!($tx, $rx);
return Err(io::Error::new(io::ErrorKind::Other, format!($($msg)+)));
};
// Prints a debug message and stores state before terminating
(@debug($tx:ident, $rx:ident) $($msg:tt)+) => {
debug!($($msg)+);
terminate_connection!($tx, $rx);
return Ok(());
};
// Prints a shutdown message with no connection id and exit without sending state
(@shutdown) => {
debug!("Shutdown triggered before a connection could be fully established");
terminate_connection!();
return Ok(());
};
// Prints a shutdown message with no connection id and stores state before terminating
(@shutdown) => {
debug!("Shutdown triggered before a connection could be fully established");
terminate_connection!();
return Ok(());
};
// Prints a shutdown message and stores state before terminating
(@shutdown($id:ident, $tx:ident, $rx:ident)) => {{
debug!("[Conn {}] Shutdown triggered", $id);
terminate_connection!($tx, $rx);
return Ok(());
}};
// Performs the connection termination by removing it from server state and
// restarting the shutdown timer if it was the last connection
($tx:ident, $rx:ident) => {
// Send the channels back
let _ = channel_tx.send(($tx, $rx));
terminate_connection!();
};
// Performs the connection termination by removing it from server state and
// restarting the shutdown timer if it was the last connection
() => {
// Remove the connection from our state if it has closed
// Restart our shutdown timer if this is the last connection
if let Some(state) = Weak::upgrade(&state) {
state.connections.write().await.remove(&self.id);
// If we have no more connections, start the timer
if let Some(timer) = Weak::upgrade(&shutdown_timer) {
if state.connections.read().await.is_empty() {
if state.connections.read().await.values().filter(|conn| !conn.is_finished()).count() <= 1 {
debug!("Last connection terminating, so restarting shutdown timer");
timer.write().await.restart();
}
}
@ -247,58 +318,160 @@ where
};
}
/// Awaits a future to complete, or detects if a signal was received by either the global
/// or local shutdown channel. Shutdown only occurs if a signal was received, and any
/// errors received by either shutdown channel are ignored.
macro_rules! await_or_shutdown {
($(@save($id:ident, $tx:ident, $rx:ident))? $future:expr) => {{
let mut f = $future;
loop {
let use_shutdown = match shutdown.try_recv() {
Ok(_) => {
terminate_connection!(@shutdown $(($id, $tx, $rx))?);
}
Err(broadcast::error::TryRecvError::Empty) => true,
Err(broadcast::error::TryRecvError::Lagged(_)) => true,
Err(broadcast::error::TryRecvError::Closed) => false,
};
let use_local_shutdown = match local_shutdown.try_recv() {
Ok(_) => {
terminate_connection!(@shutdown $(($id, $tx, $rx))?);
}
Err(oneshot::error::TryRecvError::Empty) => true,
Err(oneshot::error::TryRecvError::Closed) => false,
};
if use_shutdown && use_local_shutdown {
tokio::select! {
x = shutdown.recv() => {
if x.is_err() {
continue;
}
terminate_connection!(@shutdown $(($id, $tx, $rx))?);
}
x = &mut local_shutdown => {
if x.is_err() {
continue;
}
terminate_connection!(@shutdown $(($id, $tx, $rx))?);
}
x = &mut f => { break x; }
}
} else if use_shutdown {
tokio::select! {
x = shutdown.recv() => {
if x.is_err() {
continue;
}
terminate_connection!(@shutdown $(($id, $tx, $rx))?);
}
x = &mut f => { break x; }
}
} else if use_local_shutdown {
tokio::select! {
x = &mut local_shutdown => {
if x.is_err() {
continue;
}
terminate_connection!(@shutdown $(($id, $tx, $rx))?);
}
x = &mut f => { break x; }
}
} else {
break f.await;
}
}
}};
}
// Attempt to upgrade our handler for use with the connection going forward
let handler = match Weak::upgrade(&handler) {
Some(handler) => handler,
None => {
terminate_connection!(@fatal "Failed to setup connection because handler dropped");
}
};
// Attempt to upgrade our state for use with the connection going forward
let state = match Weak::upgrade(&state) {
Some(state) => state,
None => {
terminate_connection!(@fatal "Failed to setup connection because state dropped");
}
};
// Properly establish the connection's transport
debug!("[Conn {id}] Establishing full connection");
debug!("Establishing full connection using {transport:?}");
let mut connection = match Weak::upgrade(&verifier) {
Some(verifier) => {
match Connection::server(transport, verifier.as_ref(), keychain).await {
match await_or_shutdown!(Box::pin(Connection::server(
transport,
verifier.as_ref(),
keychain
))) {
Ok(connection) => connection,
Err(x) => {
terminate_connection!(@error "[Conn {id}] Failed to setup connection: {x}");
terminate_connection!(@fatal "Failed to setup connection: {x}");
}
}
}
None => {
terminate_connection!(@error "[Conn {id}] Verifier has been dropped");
}
};
// Attempt to upgrade our handler for use with the connection going forward
debug!("[Conn {id}] Preparing connection handler");
let handler = match Weak::upgrade(&handler) {
Some(handler) => handler,
None => {
terminate_connection!(@error "[Conn {id}] Handler has been dropped");
terminate_connection!(@fatal "Verifier has been dropped");
}
};
// Construct a queue of outgoing responses
let (tx, mut rx) = mpsc::channel::<Response<H::Response>>(1);
// Update our id to be the connection id
let id = connection.id();
// Create local data for the connection and then process it
debug!("[Conn {id}] Officially accepting connection");
let mut local_data = H::LocalData::default();
if let Err(x) = handler
.on_accept(ConnectionCtx {
connection_id: id,
local_data: &mut local_data,
})
.await
{
terminate_connection!(@error "[Conn {id}] Accepting connection failed: {x}");
if let Err(x) = await_or_shutdown!(handler.on_accept(ConnectionCtx {
connection_id: id,
local_data: &mut local_data
})) {
terminate_connection!(@fatal "[Conn {id}] Accepting connection failed: {x}");
}
let local_data = Arc::new(local_data);
let mut last_heartbeat = Instant::now();
// Restore our connection's channels if we have them, otherwise make new ones
let (tx, mut rx) = match state.connections.write().await.remove(&id) {
Some(conn) => match conn.shutdown_and_wait().await {
Some(x) => {
debug!("[Conn {id}] Marked as existing connection");
x
}
None => {
warn!("[Conn {id}] Existing connection with id, but channels not saved");
mpsc::channel::<Response<H::Response>>(1)
}
},
None => {
debug!("[Conn {id}] Marked as new connection");
mpsc::channel::<Response<H::Response>>(1)
}
};
// Store our connection details
state.connections.write().await.insert(id, connection_state);
debug!("[Conn {id}] Beginning read/write loop");
loop {
let ready = match connection
.ready(Interest::READABLE | Interest::WRITABLE)
.await
{
let ready = match await_or_shutdown!(
@save(id, tx, rx)
Box::pin(connection.ready(Interest::READABLE | Interest::WRITABLE))
) {
Ok(ready) => ready,
Err(x) => {
terminate_connection!(@error "[Conn {id}] Failed to examine ready state: {x}");
terminate_connection!(@error(tx, rx) "[Conn {id}] Failed to examine ready state: {x}");
}
};
@ -311,15 +484,14 @@ where
Ok(Some(frame)) => match UntypedRequest::from_slice(frame.as_item()) {
Ok(request) => match request.to_typed_request() {
Ok(request) => {
let reply = ServerReply {
origin_id: request.id.clone(),
tx: tx.clone(),
};
let origin_id = request.id.clone();
let ctx = ServerCtx {
connection_id: id,
request,
reply: reply.clone(),
reply: ServerReply {
origin_id,
tx: tx.clone(),
},
local_data: Arc::clone(&local_data),
};
@ -344,11 +516,11 @@ where
}
},
Ok(None) => {
terminate_connection!(@debug "[Conn {id}] Connection closed");
terminate_connection!(@debug(tx, rx) "[Conn {id}] Connection closed");
}
Err(x) if x.kind() == io::ErrorKind::WouldBlock => read_blocked = true,
Err(x) => {
terminate_connection!(@error "[Conn {id}] {x}");
terminate_connection!(@error(tx, rx) "[Conn {id}] {x}");
}
}
}
@ -356,10 +528,20 @@ where
// If our socket is ready to be written to, we try to get the next item from
// the queue and process it
if ready.is_writable() {
// Send a heartbeat if we have exceeded our last time
if last_heartbeat.elapsed() >= heartbeat_duration {
trace!("[Conn {id}] Sending heartbeat via empty frame");
match connection.try_write_frame(Frame::empty()) {
Ok(()) => (),
Err(x) if x.kind() == io::ErrorKind::WouldBlock => write_blocked = true,
Err(x) => error!("[Conn {id}] Send failed: {x}"),
}
last_heartbeat = Instant::now();
}
// If we get more data to write, attempt to write it, which will result in writing
// any queued bytes as well. Othewise, we attempt to flush any pending outgoing
// bytes that weren't sent earlier.
if let Ok(response) = rx.try_recv() {
else if let Ok(response) = rx.try_recv() {
// Log our message as a string, which can be expensive
if log_enabled!(Level::Trace) {
trace!(
@ -541,7 +723,7 @@ mod tests {
let err = task.await.unwrap_err();
assert!(
err.to_string().contains("Handler has been dropped"),
err.to_string().contains("handler dropped"),
"Unexpected error: {err}"
);
}
@ -610,6 +792,7 @@ mod tests {
let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
let verifier = Arc::new(Verifier::none());
#[derive(Debug)]
struct FakeTransport {
inner: InmemoryTransport,
fail_ready: Arc<AtomicBool>,
@ -678,7 +861,7 @@ mod tests {
let err = task.await.unwrap_err();
assert!(
err.to_string().contains("Failed to examine ready state"),
err.to_string().contains("targeted ready failure"),
"Unexpected error: {err}"
);
}
@ -722,7 +905,7 @@ mod tests {
let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
let verifier = Arc::new(Verifier::none());
ConnectionTask::build()
let _conn = ConnectionTask::build()
.handler(Arc::downgrade(&handler))
.state(Arc::downgrade(&state))
.keychain(keychain)
@ -748,4 +931,205 @@ mod tests {
let response = task.await.unwrap();
assert_eq!(response.payload, "hello");
}
#[test(tokio::test)]
async fn should_send_heartbeat_via_empty_frame_every_minimum_duration() {
let handler = Arc::new(TestServerHandler);
let state = Arc::new(ServerState::default());
let keychain = ServerKeychain::new();
let (t1, t2) = InmemoryTransport::pair(100);
let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
let verifier = Arc::new(Verifier::none());
let _conn = ConnectionTask::build()
.handler(Arc::downgrade(&handler))
.state(Arc::downgrade(&state))
.keychain(keychain)
.transport(t1)
.shutdown_timer(Arc::downgrade(&shutdown_timer))
.heartbeat_duration(Duration::from_millis(200))
.verifier(Arc::downgrade(&verifier))
.spawn();
// Spawn a task to handle establishing connection from client-side
let task = tokio::spawn(async move {
let mut client = Connection::client(t2, DummyAuthHandler)
.await
.expect("Fail to establish client-side connection");
// Verify we don't get a frame immediately
assert_eq!(
client.try_read_frame().unwrap_err().kind(),
io::ErrorKind::WouldBlock,
"got a frame early"
);
// Sleep more than our minimum heartbeat duration to ensure we get one
tokio::time::sleep(Duration::from_millis(250)).await;
assert_eq!(
client.read_frame().await.unwrap().unwrap(),
Frame::empty(),
"non-empty frame"
);
// Verify we don't get a frame immediately
assert_eq!(
client.try_read_frame().unwrap_err().kind(),
io::ErrorKind::WouldBlock,
"got a frame early"
);
// Sleep more than our minimum heartbeat duration to ensure we get one
tokio::time::sleep(Duration::from_millis(250)).await;
assert_eq!(
client.read_frame().await.unwrap().unwrap(),
Frame::empty(),
"non-empty frame"
);
});
task.await.unwrap();
}
#[test(tokio::test)]
async fn should_be_able_to_shutdown_while_establishing_connection() {
let handler = Arc::new(TestServerHandler);
let state = Arc::new(ServerState::default());
let keychain = ServerKeychain::new();
let (t1, _t2) = InmemoryTransport::pair(100);
let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
let verifier = Arc::new(Verifier::none());
let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
let conn = ConnectionTask::build()
.handler(Arc::downgrade(&handler))
.state(Arc::downgrade(&state))
.keychain(keychain)
.transport(t1)
.shutdown(shutdown_rx)
.shutdown_timer(Arc::downgrade(&shutdown_timer))
.heartbeat_duration(Duration::from_millis(200))
.verifier(Arc::downgrade(&verifier))
.spawn();
// Shutdown server connection task while it is establishing a full connection with the
// client, verifying that we do not get an error in return
shutdown_tx
.send(())
.expect("Failed to send shutdown signal");
conn.await.unwrap();
}
#[test(tokio::test)]
async fn should_be_able_to_shutdown_while_accepting_connection() {
struct HangingAcceptServerHandler;
#[async_trait]
impl ServerHandler for HangingAcceptServerHandler {
type Request = ();
type Response = ();
type LocalData = ();
async fn on_accept(&self, _: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> {
// Wait "forever" so we can ensure that we fail at this step
tokio::time::sleep(Duration::MAX).await;
Err(io::Error::new(io::ErrorKind::Other, "bad accept"))
}
async fn on_request(
&self,
_: ServerCtx<Self::Request, Self::Response, Self::LocalData>,
) {
unreachable!();
}
}
let handler = Arc::new(HangingAcceptServerHandler);
let state = Arc::new(ServerState::default());
let keychain = ServerKeychain::new();
let (t1, t2) = InmemoryTransport::pair(100);
let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
let verifier = Arc::new(Verifier::none());
let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
let conn = ConnectionTask::build()
.handler(Arc::downgrade(&handler))
.state(Arc::downgrade(&state))
.keychain(keychain)
.transport(t1)
.shutdown(shutdown_rx)
.shutdown_timer(Arc::downgrade(&shutdown_timer))
.heartbeat_duration(Duration::from_millis(200))
.verifier(Arc::downgrade(&verifier))
.spawn();
// Spawn a task to handle the client-side establishment of a full connection
let _client_task = tokio::spawn(Connection::client(t2, DummyAuthHandler));
// Shutdown server connection task while it is accepting the connection, verifying that we
// do not get an error in return
shutdown_tx
.send(())
.expect("Failed to send shutdown signal");
conn.await.unwrap();
}
#[test(tokio::test)]
async fn should_be_able_to_shutdown_while_waiting_for_connection_to_be_ready() {
struct AcceptServerHandler {
tx: mpsc::Sender<()>,
}
#[async_trait]
impl ServerHandler for AcceptServerHandler {
type Request = ();
type Response = ();
type LocalData = ();
async fn on_accept(&self, _: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> {
self.tx.send(()).await.unwrap();
Ok(())
}
async fn on_request(
&self,
_: ServerCtx<Self::Request, Self::Response, Self::LocalData>,
) {
unreachable!();
}
}
let (tx, mut rx) = mpsc::channel(100);
let handler = Arc::new(AcceptServerHandler { tx });
let state = Arc::new(ServerState::default());
let keychain = ServerKeychain::new();
let (t1, t2) = InmemoryTransport::pair(100);
let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
let verifier = Arc::new(Verifier::none());
let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
let conn = ConnectionTask::build()
.handler(Arc::downgrade(&handler))
.state(Arc::downgrade(&state))
.keychain(keychain)
.transport(t1)
.shutdown(shutdown_rx)
.shutdown_timer(Arc::downgrade(&shutdown_timer))
.heartbeat_duration(Duration::from_millis(200))
.verifier(Arc::downgrade(&verifier))
.spawn();
// Spawn a task to handle the client-side establishment of a full connection
let _client_task = tokio::spawn(Connection::client(t2, DummyAuthHandler));
// Wait to ensure we complete the accept call first
let _ = rx.recv().await;
// Shutdown server connection task while it is accepting the connection, verifying that we
// do not get an error in return
shutdown_tx
.send(())
.expect("Failed to send shutdown signal");
conn.await.unwrap();
}
}

@ -1,23 +1,21 @@
use super::ServerState;
use crate::common::AsAny;
use log::*;
use std::{
future::Future,
io,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};
use tokio::sync::broadcast;
use tokio::task::{JoinError, JoinHandle};
/// Interface to engage with a server instance
/// Interface to engage with a server instance.
pub trait ServerRef: AsAny + Send {
/// Returns true if the server is no longer running
/// Returns true if the server is no longer running.
fn is_finished(&self) -> bool;
/// Kills the internal task processing new inbound requests
fn abort(&self);
/// Sends a shutdown signal to the server.
fn shutdown(&self);
fn wait(self) -> Pin<Box<dyn Future<Output = io::Result<()>>>>
where
@ -64,7 +62,7 @@ impl dyn ServerRef {
/// Represents a generic reference to a server
pub struct GenericServerRef {
pub(crate) state: Arc<ServerState>,
pub(crate) shutdown: broadcast::Sender<()>,
pub(crate) task: JoinHandle<()>,
}
@ -74,16 +72,8 @@ impl ServerRef for GenericServerRef {
self.task.is_finished()
}
fn abort(&self) {
self.task.abort();
let state = Arc::clone(&self.state);
tokio::spawn(async move {
for (id, connection) in state.connections.read().await.iter() {
debug!("Aborting connection {}", id);
connection.abort();
}
});
fn shutdown(&self) {
let _ = self.shutdown.send(());
}
fn wait(self) -> Pin<Box<dyn Future<Output = io::Result<()>>>>

@ -29,7 +29,7 @@ impl ServerRef for TcpServerRef {
self.inner.is_finished()
}
fn abort(&self) {
self.inner.abort();
fn shutdown(&self) {
self.inner.shutdown();
}
}

@ -28,7 +28,7 @@ impl ServerRef for UnixSocketServerRef {
self.inner.is_finished()
}
fn abort(&self) {
self.inner.abort();
fn shutdown(&self) {
self.inner.shutdown();
}
}

@ -28,7 +28,7 @@ impl ServerRef for WindowsPipeServerRef {
self.inner.is_finished()
}
fn abort(&self) {
self.inner.abort();
fn shutdown(&self) {
self.inner.shutdown();
}
}

@ -1,37 +1,70 @@
use super::ConnectionTask;
use crate::common::{authentication::Keychain, Backup, ConnectionId};
use std::collections::HashMap;
use tokio::sync::{oneshot, RwLock};
use tokio::sync::{mpsc, oneshot, RwLock};
use tokio::task::JoinHandle;
/// Contains all top-level state for the server
pub struct ServerState {
/// Mapping of connection ids to their transports
pub connections: RwLock<HashMap<ConnectionId, ConnectionTask>>,
pub struct ServerState<T> {
/// Mapping of connection ids to their tasks.
pub connections: RwLock<HashMap<ConnectionId, ConnectionState<T>>>,
/// Mapping of connection ids to (OTP, backup)
pub keychain: Keychain<oneshot::Receiver<Backup>>,
}
impl ServerState {
impl<T> ServerState<T> {
pub fn new() -> Self {
Self {
connections: RwLock::new(HashMap::new()),
keychain: Keychain::new(),
}
}
/// Returns true if there is at least one active connection
pub async fn has_active_connections(&self) -> bool {
self.connections
.read()
.await
.values()
.any(|task| !task.is_finished())
}
}
impl Default for ServerState {
impl<T> Default for ServerState<T> {
fn default() -> Self {
Self::new()
}
}
pub struct ConnectionState<T> {
shutdown_tx: oneshot::Sender<()>,
task: JoinHandle<Option<(mpsc::Sender<T>, mpsc::Receiver<T>)>>,
}
impl<T: Send + 'static> ConnectionState<T> {
/// Creates new state with appropriate channels, returning
/// (shutdown receiver, channel sender, state).
#[allow(clippy::type_complexity)]
pub fn channel() -> (
oneshot::Receiver<()>,
oneshot::Sender<(mpsc::Sender<T>, mpsc::Receiver<T>)>,
Self,
) {
let (shutdown_tx, shutdown_rx) = oneshot::channel();
let (channel_tx, channel_rx) = oneshot::channel();
(
shutdown_rx,
channel_tx,
Self {
shutdown_tx,
task: tokio::spawn(async move {
match channel_rx.await {
Ok(x) => Some(x),
Err(_) => None,
}
}),
},
)
}
pub fn is_finished(&self) -> bool {
self.task.is_finished()
}
pub async fn shutdown_and_wait(self) -> Option<(mpsc::Sender<T>, mpsc::Receiver<T>)> {
let _ = self.shutdown_tx.send(());
self.task.await.unwrap()
}
}

@ -1,6 +1,6 @@
use async_trait::async_trait;
use distant_net::boxed_connect_handler;
use distant_net::client::{Client, ReconnectStrategy};
use distant_net::client::Client;
use distant_net::common::authentication::{DummyAuthHandler, Verifier};
use distant_net::common::{Destination, InmemoryTransport, Map, OneshotListener};
use distant_net::manager::{Config, ManagerClient, ManagerServer};
@ -43,7 +43,6 @@ async fn should_be_able_to_establish_a_single_connection_and_communicate_with_a_
let client = Client::build()
.auth_handler(DummyAuthHandler)
.reconnect_strategy(ReconnectStrategy::Fail)
.connector(t1)
.connect_untyped()
.await?;
@ -61,7 +60,6 @@ async fn should_be_able_to_establish_a_single_connection_and_communicate_with_a_
info!("Connecting to manager");
let mut client: ManagerClient = Client::build()
.auth_handler(DummyAuthHandler)
.reconnect_strategy(ReconnectStrategy::Fail)
.connector(t1)
.connect()
.await

@ -1,5 +1,5 @@
use async_trait::async_trait;
use distant_net::client::{Client, ReconnectStrategy};
use distant_net::client::Client;
use distant_net::common::authentication::{DummyAuthHandler, Verifier};
use distant_net::common::{InmemoryTransport, OneshotListener};
use distant_net::server::{Server, ServerCtx, ServerHandler};
@ -38,7 +38,6 @@ async fn should_be_able_to_send_and_receive_typed_payloads_between_client_and_se
let mut client: Client<(u8, String), String> = Client::build()
.auth_handler(DummyAuthHandler)
.reconnect_strategy(ReconnectStrategy::Fail)
.connector(t1)
.connect()
.await

@ -1,5 +1,5 @@
use async_trait::async_trait;
use distant_net::client::{Client, ReconnectStrategy};
use distant_net::client::Client;
use distant_net::common::authentication::{DummyAuthHandler, Verifier};
use distant_net::common::{InmemoryTransport, OneshotListener, Request};
use distant_net::server::{Server, ServerCtx, ServerHandler};
@ -38,7 +38,6 @@ async fn should_be_able_to_send_and_receive_untyped_payloads_between_client_and_
let mut client = Client::build()
.auth_handler(DummyAuthHandler)
.reconnect_strategy(ReconnectStrategy::Fail)
.connector(t1)
.connect_untyped()
.await

@ -2,7 +2,7 @@
name = "distant-ssh2"
description = "Library to enable native ssh-2 protocol for use with distant sessions"
categories = ["network-programming"]
version = "0.20.0-alpha.2"
version = "0.20.0-alpha.3"
authors = ["Chip Senkbeil <chip@senkbeil.org>"]
edition = "2021"
homepage = "https://github.com/chipsenkbeil/distant"
@ -20,7 +20,7 @@ async-compat = "0.2.1"
async-once-cell = "0.4.2"
async-trait = "0.1.58"
derive_more = { version = "0.99.17", default-features = false, features = ["display", "error"] }
distant-core = { version = "=0.20.0-alpha.2", path = "../distant-core" }
distant-core = { version = "=0.20.0-alpha.3", path = "../distant-core" }
futures = "0.3.25"
hex = "0.4.3"
log = "0.4.17"

@ -694,12 +694,11 @@ impl DistantApi for SshDistantApi {
cmd: String,
environment: Environment,
current_dir: Option<PathBuf>,
persist: bool,
pty: Option<PtySize>,
) -> io::Result<ProcessId> {
debug!(
"[Conn {}] Spawning {} {{environment: {:?}, current_dir: {:?}, persist: {}, pty: {:?}}}",
ctx.connection_id, cmd, environment, current_dir, persist, pty
"[Conn {}] Spawning {} {{environment: {:?}, current_dir: {:?}, pty: {:?}}}",
ctx.connection_id, cmd, environment, current_dir, pty
);
let global_processes = Arc::downgrade(&self.processes);
@ -744,12 +743,6 @@ impl DistantApi for SshDistantApi {
}
};
// If the process will be killed when the connection ends, we want to add it
// to our local data
if !persist {
ctx.local_data.processes.write().await.insert(id);
}
self.processes.write().await.insert(
id,
Process {

@ -7,7 +7,7 @@ use async_trait::async_trait;
use distant_core::{
data::Environment,
net::{
client::{Client, ReconnectStrategy},
client::{Client, ClientConfig},
common::authentication::{AuthHandlerMap, DummyAuthHandler, Verifier},
common::{InmemoryTransport, OneshotListener},
server::{Server, ServerRef},
@ -574,7 +574,7 @@ impl Ssh {
debug!("Attempting to connect to distant server @ {}", addr);
match Client::tcp(addr)
.auth_handler(AuthHandlerMap::new().with_static_key(key.clone()))
.timeout(timeout)
.connect_timeout(timeout)
.connect()
.await
{
@ -646,7 +646,7 @@ impl Ssh {
);
// Close out ssh client by killing the internal server and client
server.abort();
server.shutdown();
client.abort();
let _ = client.wait().await;
@ -718,8 +718,8 @@ impl Ssh {
.start(OneshotListener::from_value(t2))?;
let client = Client::build()
.auth_handler(DummyAuthHandler)
.config(ClientConfig::default().with_maximum_silence_duration())
.connector(t1)
.reconnect_strategy(ReconnectStrategy::Fail)
.connect()
.await?;
Ok((client, server))

@ -1217,7 +1217,6 @@ async fn proc_spawn_should_not_fail_even_if_process_not_found(
/* cmd */ DOES_NOT_EXIST_BIN.to_str().unwrap().to_string(),
/* environment */ Environment::new(),
/* current_dir */ None,
/* persist */ false,
/* pty */ None,
)
.await
@ -1239,7 +1238,6 @@ async fn proc_spawn_should_return_id_of_spawned_process(#[future] client: Ctx<Di
),
/* environment */ Environment::new(),
/* current_dir */ None,
/* persist */ false,
/* pty */ None,
)
.await
@ -1267,7 +1265,6 @@ async fn proc_spawn_should_send_back_stdout_periodically_when_available(
),
/* environment */ Environment::new(),
/* current_dir */ None,
/* persist */ false,
/* pty */ None,
)
.await
@ -1303,7 +1300,6 @@ async fn proc_spawn_should_send_back_stderr_periodically_when_available(
),
/* environment */ Environment::new(),
/* current_dir */ None,
/* persist */ false,
/* pty */ None,
)
.await
@ -1333,7 +1329,6 @@ async fn proc_spawn_should_send_done_signal_when_completed(#[future] client: Ctx
format!("{} {} 0.1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()),
/* environment */ Environment::new(),
/* current_dir */ None,
/* persist */ false,
/* pty */ None,
)
.await
@ -1355,7 +1350,6 @@ async fn proc_spawn_should_clear_process_from_state_when_killed(
format!("{} {} 1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()),
/* environment */ Environment::new(),
/* current_dir */ None,
/* persist */ false,
/* pty */ None,
)
.await
@ -1380,7 +1374,6 @@ async fn proc_kill_should_fail_if_process_not_running(#[future] client: Ctx<Dist
format!("{} {} 1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()),
/* environment */ Environment::new(),
/* current_dir */ None,
/* persist */ false,
/* pty */ None,
)
.await
@ -1408,7 +1401,6 @@ async fn proc_stdin_should_fail_if_process_not_running(#[future] client: Ctx<Dis
format!("{} {} 1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()),
/* environment */ Environment::new(),
/* current_dir */ None,
/* persist */ false,
/* pty */ None,
)
.await
@ -1444,7 +1436,6 @@ async fn proc_stdin_should_send_stdin_to_process(#[future] client: Ctx<DistantCl
),
/* environment */ Environment::new(),
/* current_dir */ None,
/* persist */ false,
/* pty */ None,
)
.await

@ -1199,7 +1199,6 @@ async fn proc_spawn_should_fail_if_process_not_found(
/* cmd */ DOES_NOT_EXIST_BIN.to_str().unwrap().to_string(),
/* environment */ Environment::new(),
/* current_dir */ None,
/* persist */ false,
/* pty */ None,
)
.await
@ -1223,7 +1222,6 @@ async fn proc_spawn_should_return_id_of_spawned_process(
),
/* environment */ Environment::new(),
/* current_dir */ None,
/* persist */ false,
/* pty */ None,
)
.await
@ -1251,7 +1249,6 @@ async fn proc_spawn_should_send_back_stdout_periodically_when_available(
),
/* environment */ Environment::new(),
/* current_dir */ None,
/* persist */ false,
/* pty */ None,
)
.await
@ -1287,7 +1284,6 @@ async fn proc_spawn_should_send_back_stderr_periodically_when_available(
),
/* environment */ Environment::new(),
/* current_dir */ None,
/* persist */ false,
/* pty */ None,
)
.await
@ -1319,7 +1315,6 @@ async fn proc_spawn_should_send_done_signal_when_completed(
format!("{} {} 0.1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()),
/* environment */ Environment::new(),
/* current_dir */ None,
/* persist */ false,
/* pty */ None,
)
.await
@ -1341,7 +1336,6 @@ async fn proc_spawn_should_clear_process_from_state_when_killed(
format!("{} {} 1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()),
/* environment */ Environment::new(),
/* current_dir */ None,
/* persist */ false,
/* pty */ None,
)
.await
@ -1368,7 +1362,6 @@ async fn proc_kill_should_fail_if_process_not_running(
format!("{} {} 1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()),
/* environment */ Environment::new(),
/* current_dir */ None,
/* persist */ false,
/* pty */ None,
)
.await
@ -1398,7 +1391,6 @@ async fn proc_stdin_should_fail_if_process_not_running(
format!("{} {} 1", *SCRIPT_RUNNER, SLEEP_SH.to_str().unwrap()),
/* environment */ Environment::new(),
/* current_dir */ None,
/* persist */ false,
/* pty */ None,
)
.await
@ -1434,7 +1426,6 @@ async fn proc_stdin_should_send_stdin_to_process(#[future] launched_client: Ctx<
),
/* environment */ Environment::new(),
/* current_dir */ None,
/* persist */ false,
/* pty */ None,
)
.await

@ -1,6 +1,6 @@
use crate::config::NetworkConfig;
use async_trait::async_trait;
use distant_core::net::client::{Client as NetClient, ReconnectStrategy};
use distant_core::net::client::{Client as NetClient, ClientConfig, ReconnectStrategy};
use distant_core::net::common::authentication::msg::*;
use distant_core::net::common::authentication::{
AuthHandler, AuthMethodHandler, PromptAuthMethodHandler, SingleAuthHandler,
@ -61,12 +61,15 @@ impl<T: AuthHandler + Clone> Client<T> {
for path in self.network.to_unix_socket_path_candidates() {
match NetClient::unix_socket(path)
.auth_handler(self.auth_handler.clone())
.reconnect_strategy(ReconnectStrategy::ExponentialBackoff {
base: Duration::from_secs(1),
factor: 2.0,
max_duration: None,
max_retries: None,
timeout: None,
.config(ClientConfig {
reconnect_strategy: ReconnectStrategy::ExponentialBackoff {
base: Duration::from_secs(1),
factor: 2.0,
max_duration: Some(Duration::from_secs(10)),
max_retries: None,
timeout: None,
},
..Default::default()
})
.connect()
.await
@ -100,12 +103,15 @@ impl<T: AuthHandler + Clone> Client<T> {
for name in self.network.to_windows_pipe_name_candidates() {
match NetClient::local_windows_pipe(name)
.auth_handler(self.auth_handler.clone())
.reconnect_strategy(ReconnectStrategy::ExponentialBackoff {
base: Duration::from_secs(1),
factor: 2.0,
max_duration: None,
max_retries: None,
timeout: None,
.config(ClientConfig {
reconnect_strategy: ReconnectStrategy::ExponentialBackoff {
base: Duration::from_secs(1),
factor: 2.0,
max_duration: Some(Duration::from_secs(10)),
max_retries: None,
timeout: None,
},
..Default::default()
})
.connect()
.await

@ -134,11 +134,6 @@ pub enum ClientSubcommand {
#[clap(flatten)]
network: NetworkConfig,
/// If provided, will run in persist mode, meaning that the process will not be killed if the
/// client disconnects from the server
#[clap(long)]
persist: bool,
/// If provided, will run LSP in a pty
#[clap(long)]
pty: bool,
@ -215,11 +210,6 @@ pub enum ClientSubcommand {
#[clap(long, default_value_t)]
environment: Environment,
/// If provided, will run in persist mode, meaning that the process will not be killed if the
/// client disconnects from the server
#[clap(long)]
persist: bool,
/// Optional command to run instead of $SHELL
cmd: Option<String>,
},
@ -294,14 +284,12 @@ impl ClientSubcommand {
cmd,
environment,
current_dir,
persist,
pty,
} => {
debug!("Special request spawning {:?}", cmd);
let mut proc = RemoteCommand::new()
.environment(environment)
.current_dir(current_dir)
.persist(persist)
.pty(pty)
.spawn(channel.into_client().into_channel(), cmd.as_str())
.await
@ -568,7 +556,6 @@ impl ClientSubcommand {
Self::Lsp {
connection,
network,
persist,
pty,
cmd,
..
@ -592,12 +579,9 @@ impl ClientSubcommand {
format!("Failed to open channel to connection {connection_id}")
})?;
debug!(
"Spawning LSP server (persist = {}, pty = {}): {}",
persist, pty, cmd
);
debug!("Spawning LSP server (pty = {}): {}", pty, cmd);
Lsp::new(channel.into_client().into_channel())
.spawn(cmd, persist, pty)
.spawn(cmd, pty)
.await?;
}
Self::Repl {
@ -869,7 +853,6 @@ impl ClientSubcommand {
connection,
network,
environment,
persist,
cmd,
..
} => {
@ -893,13 +876,12 @@ impl ClientSubcommand {
})?;
debug!(
"Spawning shell (environment = {:?}, persist = {}): {}",
"Spawning shell (environment = {:?}): {}",
environment,
persist,
cmd.as_deref().unwrap_or(r"$SHELL")
);
Shell::new(channel.into_client().into_channel())
.spawn(cmd, environment, persist)
.spawn(cmd, environment)
.await?;
}
}

@ -11,10 +11,9 @@ impl Lsp {
Self(channel)
}
pub async fn spawn(self, cmd: impl Into<String>, persist: bool, pty: bool) -> CliResult {
pub async fn spawn(self, cmd: impl Into<String>, pty: bool) -> CliResult {
let cmd = cmd.into();
let mut proc = RemoteLspCommand::new()
.persist(persist)
.pty(if pty {
terminal_size().map(|(Width(width), Height(height))| {
PtySize::from_rows_and_cols(height, width)

@ -25,7 +25,6 @@ impl Shell {
mut self,
cmd: impl Into<Option<String>>,
mut environment: Environment,
persist: bool,
) -> CliResult {
// Automatically add TERM=xterm-256color if not specified
if !environment.contains_key("TERM") {
@ -55,7 +54,6 @@ impl Shell {
};
let mut proc = RemoteCommand::new()
.persist(persist)
.environment(environment)
.pty(
terminal_size()

@ -1,6 +1,6 @@
use crate::config::ClientLaunchConfig;
use async_trait::async_trait;
use distant_core::net::client::{Client, ReconnectStrategy, UntypedClient};
use distant_core::net::client::{Client, ClientConfig, ReconnectStrategy, UntypedClient};
use distant_core::net::common::authentication::msg::*;
use distant_core::net::common::authentication::{
AuthHandler, Authenticator, DynAuthHandler, ProxyAuthHandler, SingleAuthHandler,
@ -210,14 +210,17 @@ impl DistantConnectHandler {
match Client::tcp(addr)
.auth_handler(DynAuthHandler::from(&mut auth_handler))
.reconnect_strategy(ReconnectStrategy::ExponentialBackoff {
base: Duration::from_secs(1),
factor: 2.0,
max_duration: None,
max_retries: None,
timeout: None,
.config(ClientConfig {
reconnect_strategy: ReconnectStrategy::ExponentialBackoff {
base: Duration::from_secs(1),
factor: 2.0,
max_duration: Some(Duration::from_secs(10)),
max_retries: None,
timeout: None,
},
..Default::default()
})
.timeout(Duration::from_secs(180))
.connect_timeout(Duration::from_secs(180))
.connect_untyped()
.await
{

@ -84,7 +84,6 @@ async fn should_support_json_to_execute_program_and_return_exit_status(
"payload": {
"type": "proc_spawn",
"cmd": cmd,
"persist": false,
"pty": null,
},
});
@ -109,7 +108,6 @@ async fn should_support_json_to_capture_and_print_stdout(mut json_repl: CtxComma
"payload": {
"type": "proc_spawn",
"cmd": cmd,
"persist": false,
"pty": null,
},
});
@ -148,7 +146,6 @@ async fn should_support_json_to_capture_and_print_stderr(mut json_repl: CtxComma
"payload": {
"type": "proc_spawn",
"cmd": cmd,
"persist": false,
"pty": null,
},
});
@ -187,7 +184,6 @@ async fn should_support_json_to_forward_stdin_to_remote_process(mut json_repl: C
"payload": {
"type": "proc_spawn",
"cmd": cmd,
"persist": false,
"pty": null,
},
});
@ -271,7 +267,6 @@ async fn should_support_json_output_for_error(mut json_repl: CtxCommand<Repl>) {
"payload": {
"type": "proc_spawn",
"cmd": DOES_NOT_EXIST_BIN.to_str().unwrap().to_string(),
"persist": false,
"pty": null,
},
});

Loading…
Cancel
Save