Rewrite to support custom authentication, handshakes for encryption/compression, and reconnecting (#146)

pull/156/head
Chip Senkbeil 1 year ago committed by GitHub
parent 7d1b3ba6f0
commit 4798b67dfe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,6 +1,6 @@
[profile.ci]
fail-fast = false
retries = 2
retries = 4
slow-timeout = { period = "60s", terminate-after = 3 }
status-level = "fail"
final-status-level = "fail"

@ -168,6 +168,7 @@ jobs:
run: cargo build --release
- name: Run CLI tests (all features)
run: cargo nextest run --profile ci --release --all-features
if: matrix.os != 'windows-latest'
ssh-launch-tests:
name: "Test ssh launch using Rust ${{ matrix.rust }} on ${{ matrix.os }}"
runs-on: ${{ matrix.os }}

248
Cargo.lock generated

@ -2,6 +2,12 @@
# It is not intended for manual editing.
version = 3
[[package]]
name = "adler"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
name = "aead"
version = "0.5.0"
@ -21,6 +27,15 @@ dependencies = [
"memchr",
]
[[package]]
name = "android_system_properties"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
dependencies = [
"libc",
]
[[package]]
name = "ansi_term"
version = "0.12.1"
@ -371,6 +386,18 @@ dependencies = [
"zeroize",
]
[[package]]
name = "chrono"
version = "0.4.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "16b0a3d9ed01224b22057780a37bb8c5dbfe1be8ba48678e7bf57ec4b385411f"
dependencies = [
"iana-time-zone",
"num-integer",
"num-traits",
"winapi",
]
[[package]]
name = "cipher"
version = "0.4.3"
@ -430,6 +457,16 @@ dependencies = [
"os_str_bytes",
]
[[package]]
name = "codespan-reporting"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3538270d33cc669650c4b093848450d380def10c331d38c768e34cac80576e6e"
dependencies = [
"termcolor",
"unicode-width",
]
[[package]]
name = "combine"
version = "4.6.4"
@ -505,6 +542,15 @@ dependencies = [
"libc",
]
[[package]]
name = "crc32fast"
version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b540bd8bc810d3885c6ea91e2018302f68baba2129ab3e88f32389ee9370880d"
dependencies = [
"cfg-if 1.0.0",
]
[[package]]
name = "crossbeam-channel"
version = "0.5.4"
@ -583,6 +629,50 @@ dependencies = [
"phf 0.11.1",
]
[[package]]
name = "cxx"
version = "1.0.81"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97abf9f0eca9e52b7f81b945524e76710e6cb2366aead23b7d4fbf72e281f888"
dependencies = [
"cc",
"cxxbridge-flags",
"cxxbridge-macro",
"link-cplusplus",
]
[[package]]
name = "cxx-build"
version = "1.0.81"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7cc32cc5fea1d894b77d269ddb9f192110069a8a9c1f1d441195fba90553dea3"
dependencies = [
"cc",
"codespan-reporting",
"once_cell",
"proc-macro2",
"quote",
"scratch",
"syn",
]
[[package]]
name = "cxxbridge-flags"
version = "1.0.81"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ca220e4794c934dc6b1207c3b42856ad4c302f2df1712e9f8d2eec5afaacf1f"
[[package]]
name = "cxxbridge-macro"
version = "1.0.81"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b846f081361125bfc8dc9d3940c84e1fd83ba54bbca7b17cd29483c828be0704"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "deltae"
version = "0.3.0"
@ -709,7 +799,7 @@ dependencies = [
[[package]]
name = "distant"
version = "0.19.0"
version = "0.20.0"
dependencies = [
"anyhow",
"assert_cmd",
@ -723,6 +813,7 @@ dependencies = [
"directories",
"distant-core",
"distant-ssh2",
"env_logger",
"flexi_logger",
"fork",
"indoc",
@ -740,6 +831,7 @@ dependencies = [
"tabled",
"terminal_size 0.2.1",
"termwiz",
"test-log",
"tokio",
"toml_edit",
"which",
@ -750,7 +842,7 @@ dependencies = [
[[package]]
name = "distant-core"
version = "0.19.0"
version = "0.20.0"
dependencies = [
"assert_fs",
"async-trait",
@ -759,7 +851,7 @@ dependencies = [
"clap",
"derive_more",
"distant-net",
"flexi_logger",
"env_logger",
"futures",
"grep",
"hex",
@ -780,6 +872,7 @@ dependencies = [
"serde_json",
"shell-words",
"strum",
"test-log",
"tokio",
"tokio-util",
"walkdir",
@ -789,13 +882,15 @@ dependencies = [
[[package]]
name = "distant-net"
version = "0.19.0"
version = "0.20.0"
dependencies = [
"async-trait",
"bytes",
"chacha20poly1305",
"derive_more",
"futures",
"dyn-clone",
"env_logger",
"flate2",
"hex",
"hkdf",
"log",
@ -807,14 +902,15 @@ dependencies = [
"serde",
"serde_bytes",
"sha2 0.10.2",
"strum",
"tempfile",
"test-log",
"tokio",
"tokio-util",
]
[[package]]
name = "distant-ssh2"
version = "0.19.0"
version = "0.20.0"
dependencies = [
"anyhow",
"assert_fs",
@ -824,7 +920,7 @@ dependencies = [
"derive_more",
"distant-core",
"dunce",
"flexi_logger",
"env_logger",
"futures",
"hex",
"indoc",
@ -837,6 +933,7 @@ dependencies = [
"serde",
"shell-words",
"smol",
"test-log",
"tokio",
"typed-path",
"wezterm-ssh",
@ -927,6 +1024,19 @@ dependencies = [
"encoding_rs",
]
[[package]]
name = "env_logger"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c90bf5f19754d10198ccb95b70664fc925bd1fc090a0fd9a6ebc54acc8cd6272"
dependencies = [
"atty",
"humantime",
"log",
"regex",
"termcolor",
]
[[package]]
name = "err-derive"
version = "0.3.1"
@ -1028,21 +1138,31 @@ version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"
[[package]]
name = "flate2"
version = "1.0.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f82b0f4c27ad9f8bfd1f3208d882da2b09c301bc1c828fd3a00d0216d2fbbff6"
dependencies = [
"crc32fast",
"miniz_oxide",
]
[[package]]
name = "flexi_logger"
version = "0.23.0"
version = "0.24.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f8790f70905b203171c21060222f18f1df5cba07317860215b7880b32aaef290"
checksum = "99659bcfd52cfece972bd00acb9dba7028094d47e699ea8b193b9aaebd5c362b"
dependencies = [
"ansi_term",
"atty",
"chrono",
"glob",
"lazy_static",
"log",
"regex",
"rustversion",
"thiserror",
"time",
]
[[package]]
@ -1390,6 +1510,36 @@ dependencies = [
"digest 0.10.3",
]
[[package]]
name = "humantime"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
[[package]]
name = "iana-time-zone"
version = "0.1.53"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "64c122667b287044802d6ce17ee2ddf13207ed924c712de9a66a5814d5b64765"
dependencies = [
"android_system_properties",
"core-foundation-sys",
"iana-time-zone-haiku",
"js-sys",
"wasm-bindgen",
"winapi",
]
[[package]]
name = "iana-time-zone-haiku"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0703ae284fc167426161c2e3f1da3ea71d94b21bedbcc9494e92b28e334e3dca"
dependencies = [
"cxx",
"cxx-build",
]
[[package]]
name = "ignore"
version = "0.4.18"
@ -1588,6 +1738,15 @@ dependencies = [
"vcpkg",
]
[[package]]
name = "link-cplusplus"
version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9272ab7b96c9046fbc5bc56c06c117cb639fe2d509df0c421cad82d2915cf369"
dependencies = [
"cc",
]
[[package]]
name = "linux-raw-sys"
version = "0.0.46"
@ -1649,6 +1808,15 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
[[package]]
name = "miniz_oxide"
version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96590ba8f175222643a85693f33d26e9c8a015f599c216509b1a6894af675d34"
dependencies = [
"adler",
]
[[package]]
name = "mio"
version = "0.8.3"
@ -1701,9 +1869,9 @@ checksum = "61807f77802ff30975e01f4f071c8ba10c022052f98b3294119f3e615d13e5be"
[[package]]
name = "notify"
version = "5.0.0-pre.15"
version = "5.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "553f9844ad0b0824605c20fb55a661679782680410abfb1a8144c2e7e437e7a7"
checksum = "ed2c66da08abae1c024c01d635253e402341b4060a12e99b31c7594063bf490a"
dependencies = [
"bitflags",
"crossbeam-channel",
@ -1738,6 +1906,16 @@ dependencies = [
"syn",
]
[[package]]
name = "num-integer"
version = "0.1.45"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9"
dependencies = [
"autocfg",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.15"
@ -1757,15 +1935,6 @@ dependencies = [
"libc",
]
[[package]]
name = "num_threads"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2819ce041d2ee131036f4fc9d6ae7ae125a3a40e97ba64d04fe799ad9dabbb44"
dependencies = [
"libc",
]
[[package]]
name = "once_cell"
version = "1.13.0"
@ -2502,6 +2671,12 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
[[package]]
name = "scratch"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8132065adcfd6e02db789d9285a0deb2f3fcb04002865ab67d5fb103533898"
[[package]]
name = "sec1"
version = "0.3.0"
@ -3019,6 +3194,17 @@ dependencies = [
"winapi",
]
[[package]]
name = "test-log"
version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38f0c854faeb68a048f0f2dc410c5ddae3bf83854ef0e4977d58306a5edef50e"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "textwrap"
version = "0.15.0"
@ -3054,24 +3240,6 @@ dependencies = [
"once_cell",
]
[[package]]
name = "time"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2702e08a7a860f005826c6815dcac101b19b5eb330c27fe4a5928fec1d20ddd"
dependencies = [
"itoa",
"libc",
"num_threads",
"time-macros",
]
[[package]]
name = "time-macros"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42657b1a6f4d817cda8e7a0ace261fe0cc946cf3a80314390b22cc61ae080792"
[[package]]
name = "tokio"
version = "1.20.1"

@ -3,7 +3,7 @@ name = "distant"
description = "Operate on a remote computer through file and process manipulation"
categories = ["command-line-utilities"]
keywords = ["cli"]
version = "0.19.0"
version = "0.20.0"
authors = ["Chip Senkbeil <chip@senkbeil.org>"]
edition = "2021"
homepage = "https://github.com/chipsenkbeil/distant"
@ -32,9 +32,9 @@ clap_complete = "3.2.3"
config = { version = "0.13.2", default-features = false, features = ["toml"] }
derive_more = { version = "0.99.17", default-features = false, features = ["display", "from", "error", "is_variant"] }
dialoguer = { version = "0.10.2", default-features = false }
distant-core = { version = "=0.19.0", path = "distant-core", features = ["clap", "schemars"] }
distant-core = { version = "=0.20.0", path = "distant-core", features = ["clap", "schemars"] }
directories = "4.0.1"
flexi_logger = "0.23.0"
flexi_logger = "0.24.1"
indoc = "1.0.7"
log = "0.4.17"
once_cell = "1.13.0"
@ -54,7 +54,7 @@ winsplit = "0.1.0"
whoami = "1.2.1"
# Optional native SSH functionality
distant-ssh2 = { version = "=0.19.0", path = "distant-ssh2", default-features = false, features = ["serde"], optional = true }
distant-ssh2 = { version = "=0.20.0", path = "distant-ssh2", default-features = false, features = ["serde"], optional = true }
[target.'cfg(unix)'.dependencies]
fork = "0.1.19"
@ -66,6 +66,8 @@ windows-service = "0.5.0"
[dev-dependencies]
assert_cmd = "2.0.4"
assert_fs = "1.0.7"
env_logger = "0.9.1"
indoc = "1.0.7"
predicates = "2.1.1"
rstest = "0.15.0"
test-log = "0.2.11"

@ -3,7 +3,7 @@ name = "distant-core"
description = "Core library for distant, enabling operation on a remote computer through file and process manipulation"
categories = ["network-programming"]
keywords = ["api", "async"]
version = "0.19.0"
version = "0.20.0"
authors = ["Chip Senkbeil <chip@senkbeil.org>"]
edition = "2021"
homepage = "https://github.com/chipsenkbeil/distant"
@ -19,13 +19,13 @@ async-trait = "0.1.57"
bitflags = "1.3.2"
bytes = "1.2.1"
derive_more = { version = "0.99.17", default-features = false, features = ["as_mut", "as_ref", "deref", "deref_mut", "display", "from", "error", "into", "into_iterator", "is_variant", "try_into"] }
distant-net = { version = "=0.19.0", path = "../distant-net" }
distant-net = { version = "=0.20.0", path = "../distant-net" }
futures = "0.3.21"
grep = "0.2.10"
hex = "0.4.3"
ignore = "0.4.18"
log = "0.4.17"
notify = { version = "=5.0.0-pre.15", features = ["serde"] }
notify = { version = "5.0.0", features = ["serde"] }
num_cpus = "1.13.1"
once_cell = "1.13.0"
portable-pty = "0.7.0"
@ -48,7 +48,8 @@ schemars = { version = "0.8.10", optional = true }
[dev-dependencies]
assert_fs = "1.0.7"
flexi_logger = "0.23.0"
env_logger = "0.9.1"
indoc = "1.0.7"
predicates = "2.1.1"
rstest = "0.15.0"
test-log = "0.2.11"

@ -3,10 +3,11 @@ use crate::{
Capabilities, ChangeKind, DirEntry, Environment, Error, Metadata, ProcessId, PtySize,
SearchId, SearchQuery, SystemInfo,
},
ConnectionId, DistantMsg, DistantRequestData, DistantResponseData,
DistantMsg, DistantRequestData, DistantResponseData,
};
use async_trait::async_trait;
use distant_net::{Reply, Server, ServerConfig, ServerCtx};
use distant_net::common::ConnectionId;
use distant_net::server::{ConnectionCtx, Reply, ServerCtx, ServerHandler};
use log::*;
use std::{io, path::PathBuf, sync::Arc};
@ -23,15 +24,15 @@ pub struct DistantCtx<T> {
pub local_data: Arc<T>,
}
/// Represents a server that leverages an API compliant with `distant`
pub struct DistantApiServer<T, D>
/// Represents a [`ServerHandler`] that leverages an API compliant with `distant`
pub struct DistantApiServerHandler<T, D>
where
T: DistantApi<LocalData = D>,
{
api: T,
}
impl<T, D> DistantApiServer<T, D>
impl<T, D> DistantApiServerHandler<T, D>
where
T: DistantApi<LocalData = D>,
{
@ -40,11 +41,11 @@ where
}
}
impl DistantApiServer<LocalDistantApi, <LocalDistantApi as DistantApi>::LocalData> {
impl DistantApiServerHandler<LocalDistantApi, <LocalDistantApi as DistantApi>::LocalData> {
/// Creates a new server using the [`LocalDistantApi`] implementation
pub fn local(config: ServerConfig) -> io::Result<Self> {
pub fn local() -> io::Result<Self> {
Ok(Self {
api: LocalDistantApi::initialize(config)?,
api: LocalDistantApi::initialize()?,
})
}
}
@ -63,15 +64,12 @@ fn unsupported<T>(label: &str) -> io::Result<T> {
pub trait DistantApi {
type LocalData: Send + Sync;
/// Returns config associated with API server
fn config(&self) -> ServerConfig {
ServerConfig::default()
}
/// Invoked whenever a new connection is established, providing a mutable reference to the
/// newly-created local data. This is a way to support modifying local data before it is used.
#[allow(unused_variables)]
async fn on_accept(&self, local_data: &mut Self::LocalData) {}
async fn on_accept(&self, ctx: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> {
Ok(())
}
/// Retrieves information about the server's capabilities.
///
@ -420,7 +418,7 @@ pub trait DistantApi {
}
#[async_trait]
impl<T, D> Server for DistantApiServer<T, D>
impl<T, D> ServerHandler for DistantApiServerHandler<T, D>
where
T: DistantApi<LocalData = D> + Send + Sync,
D: Send + Sync,
@ -429,14 +427,9 @@ where
type Response = DistantMsg<DistantResponseData>;
type LocalData = D;
/// Overridden to leverage [`DistantApi`] implementation of `config`
fn config(&self) -> ServerConfig {
T::config(&self.api)
}
/// Overridden to leverage [`DistantApi`] implementation of `on_accept`
async fn on_accept(&self, local_data: &mut Self::LocalData) {
T::on_accept(&self.api, local_data).await
async fn on_accept(&self, ctx: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> {
T::on_accept(&self.api, ctx).await
}
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) {
@ -518,7 +511,7 @@ where
/// Processes an incoming request
async fn handle_request<T, D>(
server: &DistantApiServer<T, D>,
server: &DistantApiServerHandler<T, D>,
ctx: DistantCtx<D>,
request: DistantRequestData,
) -> DistantResponseData

@ -6,7 +6,7 @@ use crate::{
DistantApi, DistantCtx,
};
use async_trait::async_trait;
use distant_net::ServerConfig;
use distant_net::server::ConnectionCtx;
use log::*;
use std::{
io,
@ -26,15 +26,13 @@ use state::*;
/// impementation of the API instead of a proxy to another machine as seen with
/// implementations on top of SSH and other protocol
pub struct LocalDistantApi {
config: ServerConfig,
state: GlobalState,
}
impl LocalDistantApi {
/// Initialize the api instance
pub fn initialize(config: ServerConfig) -> io::Result<Self> {
pub fn initialize() -> io::Result<Self> {
Ok(Self {
config,
state: GlobalState::initialize()?,
})
}
@ -44,14 +42,11 @@ impl LocalDistantApi {
impl DistantApi for LocalDistantApi {
type LocalData = ConnectionState;
fn config(&self) -> ServerConfig {
self.config.clone()
}
/// Injects the global channels into the local connection
async fn on_accept(&self, local_data: &mut Self::LocalData) {
local_data.process_channel = self.state.process.clone_channel();
local_data.watcher_channel = self.state.watcher.clone_channel();
async fn on_accept(&self, ctx: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> {
ctx.local_data.process_channel = self.state.process.clone_channel();
ctx.local_data.watcher_channel = self.state.watcher.clone_channel();
Ok(())
}
async fn capabilities(&self, ctx: DistantCtx<Self::LocalData>) -> io::Result<Capabilities> {
@ -511,10 +506,11 @@ mod tests {
use super::*;
use crate::data::DistantResponseData;
use assert_fs::prelude::*;
use distant_net::Reply;
use distant_net::server::Reply;
use once_cell::sync::Lazy;
use predicates::prelude::*;
use std::{sync::Arc, time::Duration};
use test_log::test;
use tokio::sync::mpsc;
static TEMP_SCRIPT_DIR: Lazy<assert_fs::TempDir> =
@ -583,12 +579,21 @@ mod tests {
DistantCtx<ConnectionState>,
mpsc::Receiver<DistantResponseData>,
) {
let api = LocalDistantApi::initialize(Default::default()).unwrap();
let api = LocalDistantApi::initialize().unwrap();
let (reply, rx) = make_reply(buffer);
let connection_id = rand::random();
let mut local_data = ConnectionState::default();
DistantApi::on_accept(&api, &mut local_data).await;
DistantApi::on_accept(
&api,
ConnectionCtx {
connection_id,
local_data: &mut local_data,
},
)
.await
.unwrap();
let ctx = DistantCtx {
connection_id: rand::random(),
connection_id,
reply,
local_data: Arc::new(local_data),
};
@ -605,7 +610,7 @@ mod tests {
(Box::new(tx), rx)
}
#[tokio::test]
#[test(tokio::test)]
async fn read_file_should_fail_if_file_missing() {
let (api, ctx, _rx) = setup(1).await;
let temp = assert_fs::TempDir::new().unwrap();
@ -614,7 +619,7 @@ mod tests {
let _ = api.read_file(ctx, path).await.unwrap_err();
}
#[tokio::test]
#[test(tokio::test)]
async fn read_file_should_send_blob_with_file_contents() {
let (api, ctx, _rx) = setup(1).await;
@ -626,7 +631,7 @@ mod tests {
assert_eq!(bytes, b"some file contents");
}
#[tokio::test]
#[test(tokio::test)]
async fn read_file_text_should_send_error_if_fails_to_read_file() {
let (api, ctx, _rx) = setup(1).await;
@ -636,7 +641,7 @@ mod tests {
let _ = api.read_file_text(ctx, path).await.unwrap_err();
}
#[tokio::test]
#[test(tokio::test)]
async fn read_file_text_should_send_text_with_file_contents() {
let (api, ctx, _rx) = setup(1).await;
@ -651,7 +656,7 @@ mod tests {
assert_eq!(text, "some file contents");
}
#[tokio::test]
#[test(tokio::test)]
async fn write_file_should_send_error_if_fails_to_write_file() {
let (api, ctx, _rx) = setup(1).await;
@ -669,7 +674,7 @@ mod tests {
file.assert(predicate::path::missing());
}
#[tokio::test]
#[test(tokio::test)]
async fn write_file_should_send_ok_when_successful() {
let (api, ctx, _rx) = setup(1).await;
@ -687,7 +692,7 @@ mod tests {
file.assert("some text");
}
#[tokio::test]
#[test(tokio::test)]
async fn write_file_text_should_send_error_if_fails_to_write_file() {
let (api, ctx, _rx) = setup(1).await;
@ -704,7 +709,7 @@ mod tests {
file.assert(predicate::path::missing());
}
#[tokio::test]
#[test(tokio::test)]
async fn write_file_text_should_send_ok_when_successful() {
let (api, ctx, _rx) = setup(1).await;
@ -722,7 +727,7 @@ mod tests {
file.assert("some text");
}
#[tokio::test]
#[test(tokio::test)]
async fn append_file_should_send_error_if_fails_to_create_file() {
let (api, ctx, _rx) = setup(1).await;
@ -743,7 +748,7 @@ mod tests {
file.assert(predicate::path::missing());
}
#[tokio::test]
#[test(tokio::test)]
async fn append_file_should_create_file_if_missing() {
let (api, ctx, _rx) = setup(1).await;
@ -767,7 +772,7 @@ mod tests {
file.assert("some extra contents");
}
#[tokio::test]
#[test(tokio::test)]
async fn append_file_should_send_ok_when_successful() {
let (api, ctx, _rx) = setup(1).await;
@ -791,7 +796,7 @@ mod tests {
file.assert("some file contentssome extra contents");
}
#[tokio::test]
#[test(tokio::test)]
async fn append_file_text_should_send_error_if_fails_to_create_file() {
let (api, ctx, _rx) = setup(1).await;
@ -813,7 +818,7 @@ mod tests {
file.assert(predicate::path::missing());
}
#[tokio::test]
#[test(tokio::test)]
async fn append_file_text_should_create_file_if_missing() {
let (api, ctx, _rx) = setup(1).await;
@ -837,7 +842,7 @@ mod tests {
file.assert("some extra contents");
}
#[tokio::test]
#[test(tokio::test)]
async fn append_file_text_should_send_ok_when_successful() {
let (api, ctx, _rx) = setup(1).await;
@ -861,7 +866,7 @@ mod tests {
file.assert("some file contentssome extra contents");
}
#[tokio::test]
#[test(tokio::test)]
async fn dir_read_should_send_error_if_directory_does_not_exist() {
let (api, ctx, _rx) = setup(1).await;
@ -902,7 +907,7 @@ mod tests {
root_dir
}
#[tokio::test]
#[test(tokio::test)]
async fn dir_read_should_support_depth_limits() {
let (api, ctx, _rx) = setup(1).await;
@ -936,7 +941,7 @@ mod tests {
assert_eq!(entries[2].depth, 1);
}
#[tokio::test]
#[test(tokio::test)]
async fn dir_read_should_support_unlimited_depth_using_zero() {
let (api, ctx, _rx) = setup(1).await;
@ -974,7 +979,7 @@ mod tests {
assert_eq!(entries[3].depth, 2);
}
#[tokio::test]
#[test(tokio::test)]
async fn dir_read_should_support_including_directory_in_returned_entries() {
let (api, ctx, _rx) = setup(1).await;
@ -1013,7 +1018,7 @@ mod tests {
assert_eq!(entries[3].depth, 1);
}
#[tokio::test]
#[test(tokio::test)]
async fn dir_read_should_support_returning_absolute_paths() {
let (api, ctx, _rx) = setup(1).await;
@ -1048,7 +1053,7 @@ mod tests {
assert_eq!(entries[2].depth, 1);
}
#[tokio::test]
#[test(tokio::test)]
async fn dir_read_should_support_returning_canonicalized_paths() {
let (api, ctx, _rx) = setup(1).await;
@ -1083,7 +1088,7 @@ mod tests {
assert_eq!(entries[2].depth, 1);
}
#[tokio::test]
#[test(tokio::test)]
async fn create_dir_should_send_error_if_fails() {
let (api, ctx, _rx) = setup(1).await;
@ -1101,7 +1106,7 @@ mod tests {
assert!(!path.exists(), "Path unexpectedly exists");
}
#[tokio::test]
#[test(tokio::test)]
async fn create_dir_should_send_ok_when_successful() {
let (api, ctx, _rx) = setup(1).await;
let root_dir = setup_dir().await;
@ -1115,7 +1120,7 @@ mod tests {
assert!(path.exists(), "Directory not created");
}
#[tokio::test]
#[test(tokio::test)]
async fn create_dir_should_support_creating_multiple_dir_components() {
let (api, ctx, _rx) = setup(1).await;
let root_dir = setup_dir().await;
@ -1129,7 +1134,7 @@ mod tests {
assert!(path.exists(), "Directory not created");
}
#[tokio::test]
#[test(tokio::test)]
async fn remove_should_send_error_on_failure() {
let (api, ctx, _rx) = setup(1).await;
let temp = assert_fs::TempDir::new().unwrap();
@ -1144,7 +1149,7 @@ mod tests {
file.assert(predicate::path::missing());
}
#[tokio::test]
#[test(tokio::test)]
async fn remove_should_support_deleting_a_directory() {
let (api, ctx, _rx) = setup(1).await;
let temp = assert_fs::TempDir::new().unwrap();
@ -1159,7 +1164,7 @@ mod tests {
dir.assert(predicate::path::missing());
}
#[tokio::test]
#[test(tokio::test)]
async fn remove_should_delete_nonempty_directory_if_force_is_true() {
let (api, ctx, _rx) = setup(1).await;
let temp = assert_fs::TempDir::new().unwrap();
@ -1175,7 +1180,7 @@ mod tests {
dir.assert(predicate::path::missing());
}
#[tokio::test]
#[test(tokio::test)]
async fn remove_should_support_deleting_a_single_file() {
let (api, ctx, _rx) = setup(1).await;
let temp = assert_fs::TempDir::new().unwrap();
@ -1190,7 +1195,7 @@ mod tests {
file.assert(predicate::path::missing());
}
#[tokio::test]
#[test(tokio::test)]
async fn copy_should_send_error_on_failure() {
let (api, ctx, _rx) = setup(1).await;
let temp = assert_fs::TempDir::new().unwrap();
@ -1206,7 +1211,7 @@ mod tests {
dst.assert(predicate::path::missing());
}
#[tokio::test]
#[test(tokio::test)]
async fn copy_should_support_copying_an_entire_directory() {
let (api, ctx, _rx) = setup(1).await;
let temp = assert_fs::TempDir::new().unwrap();
@ -1230,7 +1235,7 @@ mod tests {
dst_file.assert(predicate::path::eq_file(src_file.path()));
}
#[tokio::test]
#[test(tokio::test)]
async fn copy_should_support_copying_an_empty_directory() {
let (api, ctx, _rx) = setup(1).await;
let temp = assert_fs::TempDir::new().unwrap();
@ -1247,7 +1252,7 @@ mod tests {
dst.assert(predicate::path::is_dir());
}
#[tokio::test]
#[test(tokio::test)]
async fn copy_should_support_copying_a_directory_that_only_contains_directories() {
let (api, ctx, _rx) = setup(1).await;
let temp = assert_fs::TempDir::new().unwrap();
@ -1271,7 +1276,7 @@ mod tests {
dst_dir.assert(predicate::path::is_dir().name("dst/dir"));
}
#[tokio::test]
#[test(tokio::test)]
async fn copy_should_support_copying_a_single_file() {
let (api, ctx, _rx) = setup(1).await;
let temp = assert_fs::TempDir::new().unwrap();
@ -1288,7 +1293,7 @@ mod tests {
dst.assert(predicate::path::eq_file(src.path()));
}
#[tokio::test]
#[test(tokio::test)]
async fn rename_should_fail_if_path_missing() {
let (api, ctx, _rx) = setup(1).await;
let temp = assert_fs::TempDir::new().unwrap();
@ -1304,7 +1309,7 @@ mod tests {
dst.assert(predicate::path::missing());
}
#[tokio::test]
#[test(tokio::test)]
async fn rename_should_support_renaming_an_entire_directory() {
let (api, ctx, _rx) = setup(1).await;
let temp = assert_fs::TempDir::new().unwrap();
@ -1328,7 +1333,7 @@ mod tests {
dst_file.assert("some contents");
}
#[tokio::test]
#[test(tokio::test)]
async fn rename_should_support_renaming_a_single_file() {
let (api, ctx, _rx) = setup(1).await;
let temp = assert_fs::TempDir::new().unwrap();
@ -1375,7 +1380,7 @@ mod tests {
}
}
#[tokio::test]
#[test(tokio::test)]
async fn watch_should_support_watching_a_single_file() {
// NOTE: Supporting multiple replies being sent back as part of creating, modifying, etc.
let (api, ctx, mut rx) = setup(100).await;
@ -1408,7 +1413,7 @@ mod tests {
);
}
#[tokio::test]
#[test(tokio::test)]
async fn watch_should_support_watching_a_directory_recursively() {
// NOTE: Supporting multiple replies being sent back as part of creating, modifying, etc.
let (api, ctx, mut rx) = setup(100).await;
@ -1485,7 +1490,7 @@ mod tests {
);
}
#[tokio::test]
#[test(tokio::test)]
async fn watch_should_report_changes_using_the_ctx_replies() {
// NOTE: Supporting multiple replies being sent back as part of creating, modifying, etc.
let (api, ctx_1, mut rx_1) = setup(100).await;
@ -1558,7 +1563,7 @@ mod tests {
);
}
#[tokio::test]
#[test(tokio::test)]
async fn exists_should_send_true_if_path_exists() {
let (api, ctx, _rx) = setup(1).await;
let temp = assert_fs::TempDir::new().unwrap();
@ -1569,7 +1574,7 @@ mod tests {
assert!(exists, "Expected exists to be true, but was false");
}
#[tokio::test]
#[test(tokio::test)]
async fn exists_should_send_false_if_path_does_not_exist() {
let (api, ctx, _rx) = setup(1).await;
let temp = assert_fs::TempDir::new().unwrap();
@ -1579,7 +1584,7 @@ mod tests {
assert!(!exists, "Expected exists to be false, but was true");
}
#[tokio::test]
#[test(tokio::test)]
async fn metadata_should_send_error_on_failure() {
let (api, ctx, _rx) = setup(1).await;
let temp = assert_fs::TempDir::new().unwrap();
@ -1596,7 +1601,7 @@ mod tests {
.unwrap_err();
}
#[tokio::test]
#[test(tokio::test)]
async fn metadata_should_send_back_metadata_on_file_if_exists() {
let (api, ctx, _rx) = setup(1).await;
let temp = assert_fs::TempDir::new().unwrap();
@ -1630,7 +1635,7 @@ mod tests {
}
#[cfg(unix)]
#[tokio::test]
#[test(tokio::test)]
async fn metadata_should_include_unix_specific_metadata_on_unix_platform() {
let (api, ctx, _rx) = setup(1).await;
let temp = assert_fs::TempDir::new().unwrap();
@ -1660,7 +1665,7 @@ mod tests {
}
#[cfg(windows)]
#[tokio::test]
#[test(tokio::test)]
async fn metadata_should_include_windows_specific_metadata_on_windows_platform() {
let (api, ctx, _rx) = setup(1).await;
let temp = assert_fs::TempDir::new().unwrap();
@ -1689,7 +1694,7 @@ mod tests {
}
}
#[tokio::test]
#[test(tokio::test)]
async fn metadata_should_send_back_metadata_on_dir_if_exists() {
let (api, ctx, _rx) = setup(1).await;
let temp = assert_fs::TempDir::new().unwrap();
@ -1721,7 +1726,7 @@ mod tests {
);
}
#[tokio::test]
#[test(tokio::test)]
async fn metadata_should_send_back_metadata_on_symlink_if_exists() {
let (api, ctx, _rx) = setup(1).await;
let temp = assert_fs::TempDir::new().unwrap();
@ -1756,7 +1761,7 @@ mod tests {
);
}
#[tokio::test]
#[test(tokio::test)]
async fn metadata_should_include_canonicalized_path_if_flag_specified() {
let (api, ctx, _rx) = setup(1).await;
let temp = assert_fs::TempDir::new().unwrap();
@ -1791,7 +1796,7 @@ mod tests {
}
}
#[tokio::test]
#[test(tokio::test)]
async fn metadata_should_resolve_file_type_of_symlink_if_flag_specified() {
let (api, ctx, _rx) = setup(1).await;
let temp = assert_fs::TempDir::new().unwrap();
@ -1826,7 +1831,7 @@ mod tests {
// NOTE: Ignoring on windows because it's using WSL which wants a Linux path
// with / but thinks it's on windows and is providing \
#[tokio::test]
#[test(tokio::test)]
#[cfg_attr(windows, ignore)]
async fn proc_spawn_should_send_error_on_failure() {
let (api, ctx, _rx) = setup(1).await;
@ -1846,7 +1851,7 @@ mod tests {
// NOTE: Ignoring on windows because it's using WSL which wants a Linux path
// with / but thinks it's on windows and is providing \
#[tokio::test]
#[test(tokio::test)]
#[cfg_attr(windows, ignore)]
async fn proc_spawn_should_return_id_of_spawned_process() {
let (api, ctx, _rx) = setup(1).await;
@ -1872,7 +1877,7 @@ mod tests {
// NOTE: Ignoring on windows because it's using WSL which wants a Linux path
// with / but thinks it's on windows and is providing \
#[tokio::test]
#[test(tokio::test)]
#[cfg_attr(windows, ignore)]
async fn proc_spawn_should_send_back_stdout_periodically_when_available() {
let (api, ctx, mut rx) = setup(1).await;
@ -1937,7 +1942,7 @@ mod tests {
// NOTE: Ignoring on windows because it's using WSL which wants a Linux path
// with / but thinks it's on windows and is providing \
#[tokio::test]
#[test(tokio::test)]
#[cfg_attr(windows, ignore)]
async fn proc_spawn_should_send_back_stderr_periodically_when_available() {
let (api, ctx, mut rx) = setup(1).await;
@ -2002,7 +2007,7 @@ mod tests {
// NOTE: Ignoring on windows because it's using WSL which wants a Linux path
// with / but thinks it's on windows and is providing \
#[tokio::test]
#[test(tokio::test)]
#[cfg_attr(windows, ignore)]
async fn proc_spawn_should_send_done_signal_when_completed() {
let (api, ctx, mut rx) = setup(1).await;
@ -2033,7 +2038,7 @@ mod tests {
// NOTE: Ignoring on windows because it's using WSL which wants a Linux path
// with / but thinks it's on windows and is providing \
#[tokio::test]
#[test(tokio::test)]
#[cfg_attr(windows, ignore)]
async fn proc_spawn_should_clear_process_from_state_when_killed() {
let (api, ctx_1, mut rx) = setup(1).await;
@ -2074,7 +2079,7 @@ mod tests {
}
}
#[tokio::test]
#[test(tokio::test)]
async fn proc_kill_should_fail_if_given_non_existent_process() {
let (api, ctx, _rx) = setup(1).await;
@ -2082,7 +2087,7 @@ mod tests {
let _ = api.proc_kill(ctx, 0xDEADBEEF).await.unwrap_err();
}
#[tokio::test]
#[test(tokio::test)]
async fn proc_stdin_should_fail_if_given_non_existent_process() {
let (api, ctx, _rx) = setup(1).await;
@ -2095,7 +2100,7 @@ mod tests {
// NOTE: Ignoring on windows because it's using WSL which wants a Linux path
// with / but thinks it's on windows and is providing \
#[tokio::test]
#[test(tokio::test)]
#[cfg_attr(windows, ignore)]
async fn proc_stdin_should_send_stdin_to_process() {
let (api, ctx_1, mut rx) = setup(1).await;
@ -2141,7 +2146,7 @@ mod tests {
}
}
#[tokio::test]
#[test(tokio::test)]
async fn system_info_should_return_system_info_based_on_binary() {
let (api, ctx, _rx) = setup(1).await;

@ -3,7 +3,7 @@ use super::{
ProcessPty, PtySize, WaitRx,
};
use crate::{
constants::{MAX_PIPE_CHUNK_SIZE, READ_PAUSE_MILLIS},
constants::{MAX_PIPE_CHUNK_SIZE, READ_PAUSE_DURATION},
data::Environment,
};
use log::*;
@ -150,8 +150,7 @@ impl PtyProcess {
break;
}
_ => {
tokio::time::sleep(tokio::time::Duration::from_millis(READ_PAUSE_MILLIS))
.await;
tokio::time::sleep(READ_PAUSE_DURATION).await;
continue;
}
}

@ -1,4 +1,4 @@
use crate::constants::{MAX_PIPE_CHUNK_SIZE, READ_PAUSE_MILLIS};
use crate::constants::{MAX_PIPE_CHUNK_SIZE, READ_PAUSE_DURATION};
use std::io;
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
@ -34,13 +34,13 @@ where
// Pause to allow buffer to fill up a little bit, avoiding
// spamming with a lot of smaller responses
tokio::time::sleep(tokio::time::Duration::from_millis(READ_PAUSE_MILLIS)).await;
tokio::time::sleep(READ_PAUSE_DURATION).await;
}
Ok(_) => return Ok(()),
Err(x) if x.kind() == io::ErrorKind::WouldBlock => {
// Pause to allow buffer to fill up a little bit, avoiding
// spamming with a lot of smaller responses
tokio::time::sleep(tokio::time::Duration::from_millis(READ_PAUSE_MILLIS)).await;
tokio::time::sleep(READ_PAUSE_DURATION).await;
}
Err(x) => return Err(x),
}

@ -1,7 +1,5 @@
use crate::{
data::{ProcessId, SearchId},
ConnectionId,
};
use crate::data::{ProcessId, SearchId};
use distant_net::common::ConnectionId;
use std::{io, path::PathBuf};
mod process;

@ -1,5 +1,5 @@
use crate::data::{DistantResponseData, Environment, ProcessId, PtySize};
use distant_net::Reply;
use distant_net::server::Reply;
use std::{collections::HashMap, io, ops::Deref, path::PathBuf};
use tokio::{
sync::{mpsc, oneshot},

@ -4,7 +4,7 @@ use crate::{
},
data::{DistantResponseData, Environment, ProcessId, PtySize},
};
use distant_net::Reply;
use distant_net::server::Reply;
use log::*;
use std::{future::Future, io, path::PathBuf};
use tokio::task::JoinHandle;
@ -174,12 +174,9 @@ async fn stdout_task(
loop {
match stdout.recv().await {
Ok(Some(data)) => {
if let Err(x) = reply
reply
.send(DistantResponseData::ProcStdout { id, data })
.await
{
return Err(x);
}
.await?;
}
Ok(None) => return Ok(()),
Err(x) => return Err(x),
@ -195,12 +192,9 @@ async fn stderr_task(
loop {
match stderr.recv().await {
Ok(Some(data)) => {
if let Err(x) = reply
reply
.send(DistantResponseData::ProcStderr { id, data })
.await
{
return Err(x);
}
.await?;
}
Ok(None) => return Ok(()),
Err(x) => return Err(x),

@ -3,7 +3,7 @@ use crate::data::{
SearchQueryMatchData, SearchQueryOptions, SearchQueryPathMatch, SearchQuerySubmatch,
SearchQueryTarget,
};
use distant_net::Reply;
use distant_net::server::Reply;
use grep::{
matcher::Matcher,
regex::{RegexMatcher, RegexMatcherBuilder},
@ -764,6 +764,7 @@ mod tests {
use crate::data::{FileType, SearchQueryCondition, SearchQueryMatchData};
use assert_fs::prelude::*;
use std::path::PathBuf;
use test_log::test;
fn make_path(path: &str) -> PathBuf {
use std::path::MAIN_SEPARATOR;
@ -791,7 +792,7 @@ mod tests {
}
}
#[tokio::test]
#[test(tokio::test)]
async fn should_send_event_when_query_finished() {
let root = setup_dir(Vec::new());
@ -816,7 +817,7 @@ mod tests {
assert_eq!(rx.recv().await, None);
}
#[tokio::test]
#[test(tokio::test)]
async fn should_send_all_matches_at_once_by_default() {
let root = setup_dir(vec![
("path/to/file1.txt", ""),
@ -893,7 +894,7 @@ mod tests {
assert_eq!(rx.recv().await, None);
}
#[tokio::test]
#[test(tokio::test)]
async fn should_support_targeting_paths() {
let root = setup_dir(vec![
("path/to/file1.txt", ""),
@ -971,7 +972,7 @@ mod tests {
assert_eq!(rx.recv().await, None);
}
#[tokio::test]
#[test(tokio::test)]
async fn should_support_targeting_contents() {
let root = setup_dir(vec![
("path/to/file1.txt", "some\nlines of text in\na\nfile"),
@ -1047,7 +1048,7 @@ mod tests {
assert_eq!(rx.recv().await, None);
}
#[tokio::test]
#[test(tokio::test)]
async fn should_support_multiple_submatches() {
let root = setup_dir(vec![("path/to/file.txt", "aa ab ac\nba bb bc\nca cb cc")]);
@ -1139,7 +1140,7 @@ mod tests {
assert_eq!(rx.recv().await, None);
}
#[tokio::test]
#[test(tokio::test)]
async fn should_send_paginated_results_if_specified() {
let root = setup_dir(vec![
("path/to/file1.txt", "some\nlines of text in\na\nfile"),
@ -1235,7 +1236,7 @@ mod tests {
assert_eq!(rx.recv().await, None);
}
#[tokio::test]
#[test(tokio::test)]
async fn should_send_maximum_of_limit_results_if_specified() {
let root = setup_dir(vec![
("path/to/file1.txt", "some\nlines of text in\na\nfile"),
@ -1272,7 +1273,7 @@ mod tests {
assert_eq!(rx.recv().await, None);
}
#[tokio::test]
#[test(tokio::test)]
async fn should_send_maximum_of_limit_results_with_pagination_if_specified() {
let root = setup_dir(vec![
("path/to/file1.txt", "some\nlines of text in\na\nfile"),
@ -1313,7 +1314,7 @@ mod tests {
assert_eq!(rx.recv().await, None);
}
#[tokio::test]
#[test(tokio::test)]
async fn should_traverse_no_deeper_than_max_depth_if_specified() {
let root = setup_dir(vec![
("path/to/file1.txt", ""),
@ -1409,7 +1410,7 @@ mod tests {
.await;
}
#[tokio::test]
#[test(tokio::test)]
async fn should_filter_searched_paths_to_only_those_that_match_include_regex() {
let root = setup_dir(vec![
("path/to/file1.txt", "some\nlines of text in\na\nfile"),
@ -1464,7 +1465,7 @@ mod tests {
assert_eq!(rx.recv().await, None);
}
#[tokio::test]
#[test(tokio::test)]
async fn should_filter_searched_paths_to_only_those_that_do_not_match_exclude_regex() {
let root = setup_dir(vec![
("path/to/file1.txt", "some\nlines of text in\na\nfile"),
@ -1532,7 +1533,7 @@ mod tests {
assert_eq!(rx.recv().await, None);
}
#[tokio::test]
#[test(tokio::test)]
async fn should_return_binary_match_data_if_match_is_not_utf8_but_path_is_explicit() {
let root = assert_fs::TempDir::new().unwrap();
let bin_file = root.child(make_path("file.bin"));
@ -1587,7 +1588,7 @@ mod tests {
assert_eq!(rx.recv().await, None);
}
#[tokio::test]
#[test(tokio::test)]
async fn should_not_return_binary_match_data_if_match_is_not_utf8_and_not_explicit_path() {
let root = assert_fs::TempDir::new().unwrap();
let bin_file = root.child(make_path("file.bin"));
@ -1621,7 +1622,7 @@ mod tests {
assert_eq!(rx.recv().await, None);
}
#[tokio::test]
#[test(tokio::test)]
async fn should_filter_searched_paths_to_only_those_are_an_allowed_file_type() {
let root = assert_fs::TempDir::new().unwrap();
let file = root.child(make_path("file"));
@ -1708,7 +1709,7 @@ mod tests {
.await;
}
#[tokio::test]
#[test(tokio::test)]
async fn should_follow_not_symbolic_links_if_specified_in_options() {
let root = assert_fs::TempDir::new().unwrap();
@ -1766,7 +1767,7 @@ mod tests {
assert_eq!(rx.recv().await, None);
}
#[tokio::test]
#[test(tokio::test)]
async fn should_follow_symbolic_links_if_specified_in_options() {
let root = assert_fs::TempDir::new().unwrap();
@ -1825,7 +1826,7 @@ mod tests {
assert_eq!(rx.recv().await, None);
}
#[tokio::test]
#[test(tokio::test)]
async fn should_support_being_supplied_more_than_one_path() {
let root = setup_dir(vec![
("path/to/file1.txt", "some\nlines of text in\na\nfile"),

@ -1,4 +1,5 @@
use crate::{constants::SERVER_WATCHER_CAPACITY, data::ChangeKind, ConnectionId};
use crate::{constants::SERVER_WATCHER_CAPACITY, data::ChangeKind};
use distant_net::common::ConnectionId;
use log::*;
use notify::{
Config as WatcherConfig, Error as WatcherError, ErrorKind as WatcherErrorKind,
@ -41,26 +42,12 @@ impl WatcherState {
// with a large volume of watch requests
let (tx, rx) = mpsc::channel(SERVER_WATCHER_CAPACITY);
macro_rules! configure_and_spawn {
macro_rules! spawn_watcher {
($watcher:ident) => {{
// Attempt to configure watcher, but don't fail if these configurations fail
match $watcher.configure(WatcherConfig::PreciseEvents(true)) {
Ok(true) => debug!("Watcher configured for precise events"),
Ok(false) => debug!("Watcher not configured for precise events",),
Err(x) => error!("Watcher configuration for precise events failed: {}", x),
}
// Attempt to configure watcher, but don't fail if these configurations fail
match $watcher.configure(WatcherConfig::NoticeEvents(true)) {
Ok(true) => debug!("Watcher configured for notice events"),
Ok(false) => debug!("Watcher not configured for notice events",),
Err(x) => error!("Watcher configuration for notice events failed: {}", x),
}
Ok(Self {
Self {
channel: WatcherChannel { tx },
task: tokio::spawn(watcher_task($watcher, rx)),
})
}
}};
}
@ -91,7 +78,7 @@ impl WatcherState {
};
match result {
Ok(mut watcher) => configure_and_spawn!(watcher),
Ok(watcher) => Ok(spawn_watcher!(watcher)),
Err(x) => match x.kind {
// notify-rs has a bug on Mac M1 with Docker and Linux, so we detect that error
// and fall back to the poll watcher if this occurs
@ -99,9 +86,9 @@ impl WatcherState {
// https://github.com/notify-rs/notify/issues/423
WatcherErrorKind::Io(x) if x.raw_os_error() == Some(38) => {
warn!("Recommended watcher is unsupported! Falling back to polling watcher!");
let mut watcher = PollWatcher::new(event_handler!(tx))
let watcher = PollWatcher::new(event_handler!(tx), WatcherConfig::default())
.map_err(|x| io::Error::new(io::ErrorKind::Other, x))?;
configure_and_spawn!(watcher)
Ok(spawn_watcher!(watcher))
}
_ => Err(io::Error::new(io::ErrorKind::Other, x)),
},

@ -1,8 +1,6 @@
use crate::{
data::{Change, ChangeKind, ChangeKindSet, DistantResponseData, Error},
ConnectionId,
};
use distant_net::Reply;
use crate::data::{Change, ChangeKind, ChangeKindSet, DistantResponseData, Error};
use distant_net::common::ConnectionId;
use distant_net::server::Reply;
use std::{
fmt,
hash::{Hash, Hasher},

@ -1,5 +1,5 @@
use crate::{api::DistantMsg, data::DistantResponseData};
use distant_net::Reply;
use distant_net::server::Reply;
use std::{future::Future, io, pin::Pin};
/// Wrapper around a reply that can be batch or single, converting

@ -1,5 +1,5 @@
use crate::{DistantMsg, DistantRequestData, DistantResponseData};
use distant_net::{Channel, Client};
use distant_net::{client::Channel, Client};
mod ext;
mod lsp;

@ -9,7 +9,7 @@ use crate::{
},
DistantMsg,
};
use distant_net::{Channel, Request};
use distant_net::{client::Channel, common::Request};
use std::{future::Future, io, path::PathBuf, pin::Pin};
pub type AsyncReturn<'a, T, E = io::Error> =

@ -411,33 +411,33 @@ mod tests {
use super::*;
use crate::data::{DistantRequestData, DistantResponseData};
use distant_net::{
Client, FramedTransport, InmemoryTransport, IntoSplit, PlainCodec, Request, Response,
TypedAsyncRead, TypedAsyncWrite,
common::{FramedTransport, InmemoryTransport, Request, Response},
Client, ReconnectStrategy,
};
use std::{future::Future, time::Duration};
use test_log::test;
/// Timeout used with timeout function
const TIMEOUT: Duration = Duration::from_millis(50);
// Configures an lsp process with a means to send & receive data from outside
async fn spawn_lsp_process() -> (
FramedTransport<InmemoryTransport, PlainCodec>,
RemoteLspProcess,
) {
async fn spawn_lsp_process() -> (FramedTransport<InmemoryTransport>, RemoteLspProcess) {
let (mut t1, t2) = FramedTransport::pair(100);
let (writer, reader) = t2.into_split();
let session = Client::new(writer, reader).unwrap();
let spawn_task = tokio::spawn(async move {
let client = Client::spawn_inmemory(t2, ReconnectStrategy::Fail);
let spawn_task = tokio::spawn({
let channel = client.clone_channel();
async move {
RemoteLspCommand::new()
.spawn(session.clone_channel(), String::from("cmd arg"))
.spawn(channel, String::from("cmd arg"))
.await
}
});
// Wait until we get the request from the session
let req: Request<DistantRequestData> = t1.read().await.unwrap().unwrap();
let req: Request<DistantRequestData> = t1.read_frame_as().await.unwrap().unwrap();
// Send back a response through the session
t1.write(Response::new(
t1.write_frame_for(&Response::new(
req.id,
DistantResponseData::ProcSpawned { id: rand::random() },
))
@ -471,7 +471,7 @@ mod tests {
}
}
#[tokio::test]
#[test(tokio::test)]
async fn stdin_write_should_only_send_out_complete_lsp_messages() {
let (mut transport, mut proc) = spawn_lsp_process().await;
@ -486,7 +486,7 @@ mod tests {
.unwrap();
// Validate that the outgoing req is a complete LSP message
let req: Request<DistantRequestData> = transport.read().await.unwrap().unwrap();
let req: Request<DistantRequestData> = transport.read_frame_as().await.unwrap().unwrap();
match req.payload {
DistantRequestData::ProcStdin { data, .. } => {
assert_eq!(
@ -501,7 +501,7 @@ mod tests {
}
}
#[tokio::test]
#[test(tokio::test)]
async fn stdin_write_should_support_buffering_output_until_a_complete_lsp_message_is_composed()
{
let (mut transport, mut proc) = spawn_lsp_process().await;
@ -520,7 +520,7 @@ mod tests {
tokio::task::yield_now().await;
let result = timeout(
TIMEOUT,
TypedAsyncRead::<Request<DistantRequestData>>::read(&mut transport),
transport.read_frame_as::<Request<DistantRequestData>>(),
)
.await;
assert!(result.is_err(), "Unexpectedly got data: {:?}", result);
@ -529,7 +529,7 @@ mod tests {
proc.stdin.as_mut().unwrap().write(msg_b).await.unwrap();
// Validate that the outgoing req is a complete LSP message
let req: Request<DistantRequestData> = transport.read().await.unwrap().unwrap();
let req: Request<DistantRequestData> = transport.read_frame_as().await.unwrap().unwrap();
match req.payload {
DistantRequestData::ProcStdin { data, .. } => {
assert_eq!(
@ -544,7 +544,7 @@ mod tests {
}
}
#[tokio::test]
#[test(tokio::test)]
async fn stdin_write_should_only_consume_a_complete_lsp_message_even_if_more_is_written() {
let (mut transport, mut proc) = spawn_lsp_process().await;
@ -564,7 +564,7 @@ mod tests {
.unwrap();
// Validate that the outgoing req is a complete LSP message
let req: Request<DistantRequestData> = transport.read().await.unwrap().unwrap();
let req: Request<DistantRequestData> = transport.read_frame_as().await.unwrap().unwrap();
match req.payload {
DistantRequestData::ProcStdin { data, .. } => {
assert_eq!(
@ -586,7 +586,7 @@ mod tests {
);
}
#[tokio::test]
#[test(tokio::test)]
async fn stdin_write_should_support_sending_out_multiple_lsp_messages_if_all_received_at_once()
{
let (mut transport, mut proc) = spawn_lsp_process().await;
@ -613,7 +613,7 @@ mod tests {
.unwrap();
// Validate that the first outgoing req is a complete LSP message matching first
let req: Request<DistantRequestData> = transport.read().await.unwrap().unwrap();
let req: Request<DistantRequestData> = transport.read_frame_as().await.unwrap().unwrap();
match req.payload {
DistantRequestData::ProcStdin { data, .. } => {
assert_eq!(
@ -628,7 +628,7 @@ mod tests {
}
// Validate that the second outgoing req is a complete LSP message matching second
let req: Request<DistantRequestData> = transport.read().await.unwrap().unwrap();
let req: Request<DistantRequestData> = transport.read_frame_as().await.unwrap().unwrap();
match req.payload {
DistantRequestData::ProcStdin { data, .. } => {
assert_eq!(
@ -643,7 +643,7 @@ mod tests {
}
}
#[tokio::test]
#[test(tokio::test)]
async fn stdin_write_should_convert_content_with_distant_scheme_to_file_scheme() {
let (mut transport, mut proc) = spawn_lsp_process().await;
@ -658,7 +658,7 @@ mod tests {
.unwrap();
// Validate that the outgoing req is a complete LSP message
let req: Request<DistantRequestData> = transport.read().await.unwrap().unwrap();
let req: Request<DistantRequestData> = transport.read_frame_as().await.unwrap().unwrap();
match req.payload {
DistantRequestData::ProcStdin { data, .. } => {
// Verify the contents AND headers are as expected; in this case,
@ -676,13 +676,13 @@ mod tests {
}
}
#[tokio::test]
#[test(tokio::test)]
async fn stdout_read_should_yield_lsp_messages_as_strings() {
let (mut transport, mut proc) = spawn_lsp_process().await;
// Send complete LSP message as stdout to process
transport
.write(Response::new(
.write_frame_for(&Response::new(
proc.origin_id().to_string(),
DistantResponseData::ProcStdout {
id: proc.id(),
@ -706,7 +706,7 @@ mod tests {
);
}
#[tokio::test]
#[test(tokio::test)]
async fn stdout_read_should_only_yield_complete_lsp_messages() {
let (mut transport, mut proc) = spawn_lsp_process().await;
@ -718,7 +718,7 @@ mod tests {
// Send half of LSP message over stdout
transport
.write(Response::new(
.write_frame_for(&Response::new(
proc.origin_id().to_string(),
DistantResponseData::ProcStdout {
id: proc.id(),
@ -736,7 +736,7 @@ mod tests {
// Send other half of LSP message over stdout
transport
.write(Response::new(
.write_frame_for(&Response::new(
proc.origin_id().to_string(),
DistantResponseData::ProcStdout {
id: proc.id(),
@ -757,7 +757,7 @@ mod tests {
);
}
#[tokio::test]
#[test(tokio::test)]
async fn stdout_read_should_only_consume_a_complete_lsp_message_even_if_more_output_is_available(
) {
let (mut transport, mut proc) = spawn_lsp_process().await;
@ -770,7 +770,7 @@ mod tests {
// Send complete LSP message as stdout to process
transport
.write(Response::new(
.write_frame_for(&Response::new(
proc.origin_id().to_string(),
DistantResponseData::ProcStdout {
id: proc.id(),
@ -798,7 +798,7 @@ mod tests {
);
}
#[tokio::test]
#[test(tokio::test)]
async fn stdout_read_should_support_yielding_multiple_lsp_messages_if_all_received_at_once() {
let (mut transport, mut proc) = spawn_lsp_process().await;
@ -813,7 +813,7 @@ mod tests {
// Send complete LSP message as stdout to process
transport
.write(Response::new(
.write_frame_for(&Response::new(
proc.origin_id().to_string(),
DistantResponseData::ProcStdout {
id: proc.id(),
@ -849,13 +849,13 @@ mod tests {
);
}
#[tokio::test]
#[test(tokio::test)]
async fn stdout_read_should_convert_content_with_file_scheme_to_distant_scheme() {
let (mut transport, mut proc) = spawn_lsp_process().await;
// Send complete LSP message as stdout to process
transport
.write(Response::new(
.write_frame_for(&Response::new(
proc.origin_id().to_string(),
DistantResponseData::ProcStdout {
id: proc.id(),
@ -879,13 +879,13 @@ mod tests {
);
}
#[tokio::test]
#[test(tokio::test)]
async fn stderr_read_should_yield_lsp_messages_as_strings() {
let (mut transport, mut proc) = spawn_lsp_process().await;
// Send complete LSP message as stderr to process
transport
.write(Response::new(
.write_frame_for(&Response::new(
proc.origin_id().to_string(),
DistantResponseData::ProcStderr {
id: proc.id(),
@ -909,7 +909,7 @@ mod tests {
);
}
#[tokio::test]
#[test(tokio::test)]
async fn stderr_read_should_only_yield_complete_lsp_messages() {
let (mut transport, mut proc) = spawn_lsp_process().await;
@ -921,7 +921,7 @@ mod tests {
// Send half of LSP message over stderr
transport
.write(Response::new(
.write_frame_for(&Response::new(
proc.origin_id().to_string(),
DistantResponseData::ProcStderr {
id: proc.id(),
@ -939,7 +939,7 @@ mod tests {
// Send other half of LSP message over stderr
transport
.write(Response::new(
.write_frame_for(&Response::new(
proc.origin_id().to_string(),
DistantResponseData::ProcStderr {
id: proc.id(),
@ -960,7 +960,7 @@ mod tests {
);
}
#[tokio::test]
#[test(tokio::test)]
async fn stderr_read_should_only_consume_a_complete_lsp_message_even_if_more_errput_is_available(
) {
let (mut transport, mut proc) = spawn_lsp_process().await;
@ -973,7 +973,7 @@ mod tests {
// Send complete LSP message as stderr to process
transport
.write(Response::new(
.write_frame_for(&Response::new(
proc.origin_id().to_string(),
DistantResponseData::ProcStderr {
id: proc.id(),
@ -1001,7 +1001,7 @@ mod tests {
);
}
#[tokio::test]
#[test(tokio::test)]
async fn stderr_read_should_support_yielding_multiple_lsp_messages_if_all_received_at_once() {
let (mut transport, mut proc) = spawn_lsp_process().await;
@ -1016,7 +1016,7 @@ mod tests {
// Send complete LSP message as stderr to process
transport
.write(Response::new(
.write_frame_for(&Response::new(
proc.origin_id().to_string(),
DistantResponseData::ProcStderr {
id: proc.id(),
@ -1052,13 +1052,13 @@ mod tests {
);
}
#[tokio::test]
#[test(tokio::test)]
async fn stderr_read_should_convert_content_with_file_scheme_to_distant_scheme() {
let (mut transport, mut proc) = spawn_lsp_process().await;
// Send complete LSP message as stderr to process
transport
.write(Response::new(
.write_frame_for(&Response::new(
proc.origin_id().to_string(),
DistantResponseData::ProcStderr {
id: proc.id(),

@ -310,7 +310,7 @@ fn swap_prefix(obj: &mut Map<String, Value>, old: &str, new: &str) {
let check = |s: &String| s.starts_with(old);
let mut mutate = |s: &mut String| {
if let Some(pos) = s.find(old) {
s.replace_range(pos..old.len(), new);
s.replace_range(pos..pos + old.len(), new);
}
};
@ -396,6 +396,7 @@ impl FromStr for LspContent {
#[cfg(test)]
mod tests {
use super::*;
use test_log::test;
macro_rules! make_obj {
($($tail:tt)*) => {

@ -4,7 +4,10 @@ use crate::{
data::{Cmd, DistantRequestData, DistantResponseData, Environment, ProcessId, PtySize},
DistantMsg,
};
use distant_net::{Mailbox, Request, Response};
use distant_net::{
client::Mailbox,
common::{Request, Response},
};
use log::*;
use std::{path::PathBuf, sync::Arc};
use tokio::{
@ -609,21 +612,18 @@ mod tests {
data::{Error, ErrorKind},
};
use distant_net::{
Client, FramedTransport, InmemoryTransport, IntoSplit, PlainCodec, Response,
TypedAsyncRead, TypedAsyncWrite,
common::{FramedTransport, InmemoryTransport, Response},
Client, ReconnectStrategy,
};
use std::time::Duration;
use test_log::test;
fn make_session() -> (
FramedTransport<InmemoryTransport, PlainCodec>,
DistantClient,
) {
fn make_session() -> (FramedTransport<InmemoryTransport>, DistantClient) {
let (t1, t2) = FramedTransport::pair(100);
let (writer, reader) = t2.into_split();
(t1, Client::new(writer, reader).unwrap())
(t1, Client::spawn_inmemory(t2, ReconnectStrategy::Fail))
}
#[tokio::test]
#[test(tokio::test)]
async fn spawn_should_return_invalid_data_if_received_batch_response() {
let (mut transport, session) = make_session();
@ -636,11 +636,12 @@ mod tests {
});
// Wait until we get the request from the session
let req: Request<DistantMsg<DistantRequestData>> = transport.read().await.unwrap().unwrap();
let req: Request<DistantMsg<DistantRequestData>> =
transport.read_frame_as().await.unwrap().unwrap();
// Send back a response through the session
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id,
DistantMsg::Batch(vec![DistantResponseData::ProcSpawned { id: 1 }]),
))
@ -654,7 +655,7 @@ mod tests {
}
}
#[tokio::test]
#[test(tokio::test)]
async fn spawn_should_return_invalid_data_if_did_not_get_a_indicator_that_process_started() {
let (mut transport, session) = make_session();
@ -667,11 +668,12 @@ mod tests {
});
// Wait until we get the request from the session
let req: Request<DistantMsg<DistantRequestData>> = transport.read().await.unwrap().unwrap();
let req: Request<DistantMsg<DistantRequestData>> =
transport.read_frame_as().await.unwrap().unwrap();
// Send back a response through the session
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id,
DistantMsg::Single(DistantResponseData::Error(Error {
kind: ErrorKind::BrokenPipe,
@ -688,7 +690,7 @@ mod tests {
}
}
#[tokio::test]
#[test(tokio::test)]
async fn kill_should_return_error_if_internal_tasks_already_completed() {
let (mut transport, session) = make_session();
@ -701,12 +703,13 @@ mod tests {
});
// Wait until we get the request from the session
let req: Request<DistantMsg<DistantRequestData>> = transport.read().await.unwrap().unwrap();
let req: Request<DistantMsg<DistantRequestData>> =
transport.read_frame_as().await.unwrap().unwrap();
// Send back a response through the session
let id = 12345;
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id,
DistantMsg::Single(DistantResponseData::ProcSpawned { id }),
))
@ -726,7 +729,7 @@ mod tests {
}
}
#[tokio::test]
#[test(tokio::test)]
async fn kill_should_send_proc_kill_request_and_then_cause_stdin_forwarding_to_close() {
let (mut transport, session) = make_session();
@ -739,12 +742,13 @@ mod tests {
});
// Wait until we get the request from the session
let req: Request<DistantMsg<DistantRequestData>> = transport.read().await.unwrap().unwrap();
let req: Request<DistantMsg<DistantRequestData>> =
transport.read_frame_as().await.unwrap().unwrap();
// Send back a response through the session
let id = 12345;
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id,
DistantMsg::Single(DistantResponseData::ProcSpawned { id }),
))
@ -756,7 +760,8 @@ mod tests {
assert!(proc.kill().await.is_ok(), "Failed to send kill request");
// Verify the kill request was sent
let req: Request<DistantMsg<DistantRequestData>> = transport.read().await.unwrap().unwrap();
let req: Request<DistantMsg<DistantRequestData>> =
transport.read_frame_as().await.unwrap().unwrap();
match req.payload {
DistantMsg::Single(DistantRequestData::ProcKill { id: proc_id }) => {
assert_eq!(proc_id, id)
@ -777,7 +782,7 @@ mod tests {
);
}
#[tokio::test]
#[test(tokio::test)]
async fn stdin_should_be_forwarded_from_receiver_field() {
let (mut transport, session) = make_session();
@ -790,12 +795,13 @@ mod tests {
});
// Wait until we get the request from the session
let req: Request<DistantMsg<DistantRequestData>> = transport.read().await.unwrap().unwrap();
let req: Request<DistantMsg<DistantRequestData>> =
transport.read_frame_as().await.unwrap().unwrap();
// Send back a response through the session
let id = 12345;
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id,
DistantMsg::Single(DistantResponseData::ProcSpawned { id }),
))
@ -812,7 +818,8 @@ mod tests {
.unwrap();
// Verify that a request is made through the session
let req: Request<DistantMsg<DistantRequestData>> = transport.read().await.unwrap().unwrap();
let req: Request<DistantMsg<DistantRequestData>> =
transport.read_frame_as().await.unwrap().unwrap();
match req.payload {
DistantMsg::Single(DistantRequestData::ProcStdin { id, data }) => {
assert_eq!(id, 12345);
@ -822,7 +829,7 @@ mod tests {
}
}
#[tokio::test]
#[test(tokio::test)]
async fn stdout_should_be_forwarded_to_receiver_field() {
let (mut transport, session) = make_session();
@ -835,12 +842,13 @@ mod tests {
});
// Wait until we get the request from the session
let req: Request<DistantMsg<DistantRequestData>> = transport.read().await.unwrap().unwrap();
let req: Request<DistantMsg<DistantRequestData>> =
transport.read_frame_as().await.unwrap().unwrap();
// Send back a response through the session
let id = 12345;
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id.clone(),
DistantMsg::Single(DistantResponseData::ProcSpawned { id }),
))
@ -851,7 +859,7 @@ mod tests {
let mut proc = spawn_task.await.unwrap().unwrap();
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id,
DistantMsg::Single(DistantResponseData::ProcStdout {
id,
@ -865,7 +873,7 @@ mod tests {
assert_eq!(out, b"some out");
}
#[tokio::test]
#[test(tokio::test)]
async fn stderr_should_be_forwarded_to_receiver_field() {
let (mut transport, session) = make_session();
@ -878,12 +886,13 @@ mod tests {
});
// Wait until we get the request from the session
let req: Request<DistantMsg<DistantRequestData>> = transport.read().await.unwrap().unwrap();
let req: Request<DistantMsg<DistantRequestData>> =
transport.read_frame_as().await.unwrap().unwrap();
// Send back a response through the session
let id = 12345;
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id.clone(),
DistantMsg::Single(DistantResponseData::ProcSpawned { id }),
))
@ -894,7 +903,7 @@ mod tests {
let mut proc = spawn_task.await.unwrap().unwrap();
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id,
DistantMsg::Single(DistantResponseData::ProcStderr {
id,
@ -908,7 +917,7 @@ mod tests {
assert_eq!(out, b"some err");
}
#[tokio::test]
#[test(tokio::test)]
async fn status_should_return_none_if_not_done() {
let (mut transport, session) = make_session();
@ -921,12 +930,13 @@ mod tests {
});
// Wait until we get the request from the session
let req: Request<DistantMsg<DistantRequestData>> = transport.read().await.unwrap().unwrap();
let req: Request<DistantMsg<DistantRequestData>> =
transport.read_frame_as().await.unwrap().unwrap();
// Send back a response through the session
let id = 12345;
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id,
DistantMsg::Single(DistantResponseData::ProcSpawned { id }),
))
@ -940,7 +950,7 @@ mod tests {
assert_eq!(result, None, "Unexpectedly got proc status: {:?}", result);
}
#[tokio::test]
#[test(tokio::test)]
async fn status_should_return_false_for_success_if_internal_tasks_fail() {
let (mut transport, session) = make_session();
@ -953,12 +963,13 @@ mod tests {
});
// Wait until we get the request from the session
let req: Request<DistantMsg<DistantRequestData>> = transport.read().await.unwrap().unwrap();
let req: Request<DistantMsg<DistantRequestData>> =
transport.read_frame_as().await.unwrap().unwrap();
// Send back a response through the session
let id = 12345;
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id,
DistantMsg::Single(DistantResponseData::ProcSpawned { id }),
))
@ -986,7 +997,7 @@ mod tests {
}
}
#[tokio::test]
#[test(tokio::test)]
async fn status_should_return_process_status_when_done() {
let (mut transport, session) = make_session();
@ -999,12 +1010,13 @@ mod tests {
});
// Wait until we get the request from the session
let req: Request<DistantMsg<DistantRequestData>> = transport.read().await.unwrap().unwrap();
let req: Request<DistantMsg<DistantRequestData>> =
transport.read_frame_as().await.unwrap().unwrap();
// Send back a response through the session
let id = 12345;
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id.clone(),
DistantMsg::Single(DistantResponseData::ProcSpawned { id }),
))
@ -1016,7 +1028,7 @@ mod tests {
// Send a process completion response to pass along exit status and conclude wait
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id,
DistantMsg::Single(DistantResponseData::ProcDone {
id,
@ -1040,7 +1052,7 @@ mod tests {
);
}
#[tokio::test]
#[test(tokio::test)]
async fn wait_should_return_error_if_internal_tasks_fail() {
let (mut transport, session) = make_session();
@ -1053,12 +1065,13 @@ mod tests {
});
// Wait until we get the request from the session
let req: Request<DistantMsg<DistantRequestData>> = transport.read().await.unwrap().unwrap();
let req: Request<DistantMsg<DistantRequestData>> =
transport.read_frame_as().await.unwrap().unwrap();
// Send back a response through the session
let id = 12345;
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id,
DistantMsg::Single(DistantResponseData::ProcSpawned { id }),
))
@ -1075,7 +1088,7 @@ mod tests {
}
}
#[tokio::test]
#[test(tokio::test)]
async fn wait_should_return_error_if_connection_terminates_before_receiving_done_response() {
let (mut transport, session) = make_session();
@ -1088,12 +1101,13 @@ mod tests {
});
// Wait until we get the request from the session
let req: Request<DistantMsg<DistantRequestData>> = transport.read().await.unwrap().unwrap();
let req: Request<DistantMsg<DistantRequestData>> =
transport.read_frame_as().await.unwrap().unwrap();
// Send back a response through the session
let id = 12345;
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id,
DistantMsg::Single(DistantResponseData::ProcSpawned { id }),
))
@ -1117,7 +1131,7 @@ mod tests {
}
}
#[tokio::test]
#[test(tokio::test)]
async fn receiving_done_response_should_result_in_wait_returning_exit_information() {
let (mut transport, session) = make_session();
@ -1130,12 +1144,13 @@ mod tests {
});
// Wait until we get the request from the session
let req: Request<DistantMsg<DistantRequestData>> = transport.read().await.unwrap().unwrap();
let req: Request<DistantMsg<DistantRequestData>> =
transport.read_frame_as().await.unwrap().unwrap();
// Send back a response through the session
let id = 12345;
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id.clone(),
DistantMsg::Single(DistantResponseData::ProcSpawned { id }),
))
@ -1148,7 +1163,7 @@ mod tests {
// Send a process completion response to pass along exit status and conclude wait
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id,
DistantMsg::Single(DistantResponseData::ProcDone {
id,
@ -1169,7 +1184,7 @@ mod tests {
);
}
#[tokio::test]
#[test(tokio::test)]
async fn receiving_done_response_should_result_in_output_returning_exit_information() {
let (mut transport, session) = make_session();
@ -1182,12 +1197,13 @@ mod tests {
});
// Wait until we get the request from the session
let req: Request<DistantMsg<DistantRequestData>> = transport.read().await.unwrap().unwrap();
let req: Request<DistantMsg<DistantRequestData>> =
transport.read_frame_as().await.unwrap().unwrap();
// Send back a response through the session
let id = 12345;
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id.clone(),
DistantMsg::Single(DistantResponseData::ProcSpawned { id }),
))
@ -1200,7 +1216,7 @@ mod tests {
// Send some stdout
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id.clone(),
DistantMsg::Single(DistantResponseData::ProcStdout {
id,
@ -1212,7 +1228,7 @@ mod tests {
// Send some stderr
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id.clone(),
DistantMsg::Single(DistantResponseData::ProcStderr {
id,
@ -1224,7 +1240,7 @@ mod tests {
// Send a process completion response to pass along exit status and conclude wait
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id,
DistantMsg::Single(DistantResponseData::ProcDone {
id,

@ -4,7 +4,7 @@ use crate::{
data::{DistantRequestData, DistantResponseData, SearchId, SearchQuery, SearchQueryMatch},
DistantMsg,
};
use distant_net::Request;
use distant_net::common::Request;
use log::*;
use std::{fmt, io};
use tokio::{sync::mpsc, task::JoinHandle};
@ -197,22 +197,19 @@ mod tests {
};
use crate::DistantClient;
use distant_net::{
Client, FramedTransport, InmemoryTransport, IntoSplit, PlainCodec, Response,
TypedAsyncRead, TypedAsyncWrite,
common::{FramedTransport, InmemoryTransport, Response},
Client, ReconnectStrategy,
};
use std::{path::PathBuf, sync::Arc};
use test_log::test;
use tokio::sync::Mutex;
fn make_session() -> (
FramedTransport<InmemoryTransport, PlainCodec>,
DistantClient,
) {
fn make_session() -> (FramedTransport<InmemoryTransport>, DistantClient) {
let (t1, t2) = FramedTransport::pair(100);
let (writer, reader) = t2.into_split();
(t1, Client::new(writer, reader).unwrap())
(t1, Client::spawn_inmemory(t2, ReconnectStrategy::Fail))
}
#[tokio::test]
#[test(tokio::test)]
async fn searcher_should_have_query_reflect_ongoing_query() {
let (mut transport, session) = make_session();
let test_query = SearchQuery {
@ -232,11 +229,11 @@ mod tests {
};
// Wait until we get the request from the session
let req: Request<DistantRequestData> = transport.read().await.unwrap().unwrap();
let req: Request<DistantRequestData> = transport.read_frame_as().await.unwrap().unwrap();
// Send back an acknowledgement that a search was started
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id,
DistantResponseData::SearchStarted { id: rand::random() },
))
@ -248,7 +245,7 @@ mod tests {
assert_eq!(searcher.query(), &test_query);
}
#[tokio::test]
#[test(tokio::test)]
async fn searcher_should_support_getting_next_match() {
let (mut transport, session) = make_session();
let test_query = SearchQuery {
@ -268,12 +265,12 @@ mod tests {
);
// Wait until we get the request from the session
let req: Request<DistantRequestData> = transport.read().await.unwrap().unwrap();
let req: Request<DistantRequestData> = transport.read_frame_as().await.unwrap().unwrap();
// Send back an acknowledgement that a searcher was created
let id = rand::random::<SearchId>();
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id.clone(),
DistantResponseData::SearchStarted { id },
))
@ -285,7 +282,7 @@ mod tests {
// Send some matches related to the file
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id,
vec![
DistantResponseData::SearchResults {
@ -366,7 +363,7 @@ mod tests {
);
}
#[tokio::test]
#[test(tokio::test)]
async fn searcher_should_distinguish_match_events_and_only_receive_matches_for_itself() {
let (mut transport, session) = make_session();
@ -387,12 +384,12 @@ mod tests {
);
// Wait until we get the request from the session
let req: Request<DistantRequestData> = transport.read().await.unwrap().unwrap();
let req: Request<DistantRequestData> = transport.read_frame_as().await.unwrap().unwrap();
// Send back an acknowledgement that a searcher was created
let id = rand::random();
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id.clone(),
DistantResponseData::SearchStarted { id },
))
@ -404,7 +401,7 @@ mod tests {
// Send a match from the appropriate origin
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id.clone(),
DistantResponseData::SearchResults {
id,
@ -423,7 +420,7 @@ mod tests {
// Send a chanmatchge from a different origin
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id.clone() + "1",
DistantResponseData::SearchResults {
id,
@ -442,7 +439,7 @@ mod tests {
// Send a chanmatchge from the appropriate origin
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id,
DistantResponseData::SearchResults {
id,
@ -487,7 +484,7 @@ mod tests {
);
}
#[tokio::test]
#[test(tokio::test)]
async fn searcher_should_stop_receiving_events_if_cancelled() {
let (mut transport, session) = make_session();
@ -508,12 +505,12 @@ mod tests {
);
// Wait until we get the request from the session
let req: Request<DistantRequestData> = transport.read().await.unwrap().unwrap();
let req: Request<DistantRequestData> = transport.read_frame_as().await.unwrap().unwrap();
// Send back an acknowledgement that a watcher was created
let id = rand::random::<SearchId>();
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id.clone(),
DistantResponseData::SearchStarted { id },
))
@ -522,7 +519,7 @@ mod tests {
// Send some matches from the appropriate origin
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id,
DistantResponseData::SearchResults {
id,
@ -579,10 +576,10 @@ mod tests {
let searcher_2 = Arc::clone(&searcher);
let cancel_task = tokio::spawn(async move { searcher_2.lock().await.cancel().await });
let req: Request<DistantRequestData> = transport.read().await.unwrap().unwrap();
let req: Request<DistantRequestData> = transport.read_frame_as().await.unwrap().unwrap();
transport
.write(Response::new(req.id.clone(), DistantResponseData::Ok))
.write_frame_for(&Response::new(req.id.clone(), DistantResponseData::Ok))
.await
.unwrap();
@ -591,7 +588,7 @@ mod tests {
// Send a match that will get ignored
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id,
DistantResponseData::SearchResults {
id,

@ -4,7 +4,7 @@ use crate::{
data::{Change, ChangeKindSet, DistantRequestData, DistantResponseData},
DistantMsg,
};
use distant_net::Request;
use distant_net::common::Request;
use log::*;
use std::{
fmt, io,
@ -185,22 +185,19 @@ mod tests {
use crate::data::ChangeKind;
use crate::DistantClient;
use distant_net::{
Client, FramedTransport, InmemoryTransport, IntoSplit, PlainCodec, Response,
TypedAsyncRead, TypedAsyncWrite,
common::{FramedTransport, InmemoryTransport, Response},
Client, ReconnectStrategy,
};
use std::sync::Arc;
use test_log::test;
use tokio::sync::Mutex;
fn make_session() -> (
FramedTransport<InmemoryTransport, PlainCodec>,
DistantClient,
) {
fn make_session() -> (FramedTransport<InmemoryTransport>, DistantClient) {
let (t1, t2) = FramedTransport::pair(100);
let (writer, reader) = t2.into_split();
(t1, Client::new(writer, reader).unwrap())
(t1, Client::spawn_inmemory(t2, ReconnectStrategy::Fail))
}
#[tokio::test]
#[test(tokio::test)]
async fn watcher_should_have_path_reflect_watched_path() {
let (mut transport, session) = make_session();
let test_path = Path::new("/some/test/path");
@ -219,11 +216,11 @@ mod tests {
});
// Wait until we get the request from the session
let req: Request<DistantRequestData> = transport.read().await.unwrap().unwrap();
let req: Request<DistantRequestData> = transport.read_frame_as().await.unwrap().unwrap();
// Send back an acknowledgement that a watcher was created
transport
.write(Response::new(req.id, DistantResponseData::Ok))
.write_frame_for(&Response::new(req.id, DistantResponseData::Ok))
.await
.unwrap();
@ -232,7 +229,7 @@ mod tests {
assert_eq!(watcher.path(), test_path);
}
#[tokio::test]
#[test(tokio::test)]
async fn watcher_should_support_getting_next_change() {
let (mut transport, session) = make_session();
let test_path = Path::new("/some/test/path");
@ -251,11 +248,11 @@ mod tests {
});
// Wait until we get the request from the session
let req: Request<DistantRequestData> = transport.read().await.unwrap().unwrap();
let req: Request<DistantRequestData> = transport.read_frame_as().await.unwrap().unwrap();
// Send back an acknowledgement that a watcher was created
transport
.write(Response::new(req.id.clone(), DistantResponseData::Ok))
.write_frame_for(&Response::new(req.id.clone(), DistantResponseData::Ok))
.await
.unwrap();
@ -264,7 +261,7 @@ mod tests {
// Send some changes related to the file
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id,
vec![
DistantResponseData::Changed(Change {
@ -300,7 +297,7 @@ mod tests {
);
}
#[tokio::test]
#[test(tokio::test)]
async fn watcher_should_distinguish_change_events_and_only_receive_changes_for_itself() {
let (mut transport, session) = make_session();
let test_path = Path::new("/some/test/path");
@ -319,11 +316,11 @@ mod tests {
});
// Wait until we get the request from the session
let req: Request<DistantRequestData> = transport.read().await.unwrap().unwrap();
let req: Request<DistantRequestData> = transport.read_frame_as().await.unwrap().unwrap();
// Send back an acknowledgement that a watcher was created
transport
.write(Response::new(req.id.clone(), DistantResponseData::Ok))
.write_frame_for(&Response::new(req.id.clone(), DistantResponseData::Ok))
.await
.unwrap();
@ -332,7 +329,7 @@ mod tests {
// Send a change from the appropriate origin
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id.clone(),
DistantResponseData::Changed(Change {
kind: ChangeKind::Access,
@ -344,7 +341,7 @@ mod tests {
// Send a change from a different origin
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id.clone() + "1",
DistantResponseData::Changed(Change {
kind: ChangeKind::Content,
@ -356,7 +353,7 @@ mod tests {
// Send a change from the appropriate origin
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id,
DistantResponseData::Changed(Change {
kind: ChangeKind::Remove,
@ -386,7 +383,7 @@ mod tests {
);
}
#[tokio::test]
#[test(tokio::test)]
async fn watcher_should_stop_receiving_events_if_unwatched() {
let (mut transport, session) = make_session();
let test_path = Path::new("/some/test/path");
@ -405,17 +402,17 @@ mod tests {
});
// Wait until we get the request from the session
let req: Request<DistantRequestData> = transport.read().await.unwrap().unwrap();
let req: Request<DistantRequestData> = transport.read_frame_as().await.unwrap().unwrap();
// Send back an acknowledgement that a watcher was created
transport
.write(Response::new(req.id.clone(), DistantResponseData::Ok))
.write_frame_for(&Response::new(req.id.clone(), DistantResponseData::Ok))
.await
.unwrap();
// Send some changes from the appropriate origin
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id,
vec![
DistantResponseData::Changed(Change {
@ -461,10 +458,10 @@ mod tests {
let watcher_2 = Arc::clone(&watcher);
let unwatch_task = tokio::spawn(async move { watcher_2.lock().await.unwatch().await });
let req: Request<DistantRequestData> = transport.read().await.unwrap().unwrap();
let req: Request<DistantRequestData> = transport.read_frame_as().await.unwrap().unwrap();
transport
.write(Response::new(req.id.clone(), DistantResponseData::Ok))
.write_frame_for(&Response::new(req.id.clone(), DistantResponseData::Ok))
.await
.unwrap();
@ -472,7 +469,7 @@ mod tests {
unwatch_task.await.unwrap().unwrap();
transport
.write(Response::new(
.write_frame_for(&Response::new(
req.id,
DistantResponseData::Changed(Change {
kind: ChangeKind::Unknown,

@ -1,3 +1,5 @@
use std::time::Duration;
/// Capacity associated stdin, stdout, and stderr pipes receiving data from remote server
pub const CLIENT_PIPE_CAPACITY: usize = 10000;
@ -18,4 +20,4 @@ pub const MAX_PIPE_CHUNK_SIZE: usize = 16384;
/// Duration in milliseconds to sleep between reading stdout/stderr chunks
/// to avoid sending many small messages to clients
pub const READ_PAUSE_MILLIS: u64 = 50;
pub const READ_PAUSE_DURATION: Duration = Duration::from_millis(1);

@ -1,8 +1,5 @@
use crate::{
serde_str::{deserialize_from_str, serialize_to_str},
Destination, Host,
};
use distant_net::SecretKey32;
use crate::serde_str::{deserialize_from_str, serialize_to_str};
use distant_net::common::{Destination, Host, SecretKey32};
use serde::{de::Deserializer, ser::Serializer, Deserialize, Serialize};
use std::{convert::TryFrom, fmt, io, str::FromStr};
@ -154,6 +151,7 @@ mod tests {
use super::*;
use once_cell::sync::Lazy;
use std::net::{Ipv4Addr, Ipv6Addr};
use test_log::test;
const HOST: &str = "testhost";
const PORT: u16 = 12345;

@ -24,9 +24,6 @@ pub use error::*;
mod filesystem;
pub use filesystem::*;
mod map;
pub use map::Map;
mod metadata;
pub use metadata::*;
@ -46,7 +43,7 @@ pub(crate) use utils::*;
pub type ProcessId = u32;
/// Mapping of environment variables
pub type Environment = Map;
pub type Environment = distant_net::common::Map;
/// Type alias for a vec of bytes
///

@ -391,6 +391,7 @@ mod tests {
mod search_query_condition {
use super::*;
use test_log::test;
#[test]
fn to_regex_string_should_convert_to_appropriate_regex_and_escape_as_needed() {

@ -8,10 +8,7 @@ mod credentials;
pub use credentials::*;
pub mod data;
pub use data::{DistantMsg, DistantRequestData, DistantResponseData, Map};
mod manager;
pub use manager::*;
pub use data::{DistantMsg, DistantRequestData, DistantResponseData};
mod constants;
mod serde_str;

@ -1,783 +0,0 @@
use super::data::{
ConnectionId, ConnectionInfo, ConnectionList, Destination, ManagerCapabilities, ManagerRequest,
ManagerResponse,
};
use crate::{
DistantChannel, DistantClient, DistantMsg, DistantRequestData, DistantResponseData, Map,
};
use distant_net::{
router, Auth, AuthServer, Client, IntoSplit, MpscTransport, OneshotListener, Request, Response,
ServerExt, ServerRef, UntypedTransportRead, UntypedTransportWrite,
};
use log::*;
use std::{
collections::HashMap,
io,
ops::{Deref, DerefMut},
};
use tokio::task::JoinHandle;
mod config;
pub use config::*;
mod ext;
pub use ext::*;
router!(DistantManagerClientRouter {
auth_transport: Request<Auth> => Response<Auth>,
manager_transport: Response<ManagerResponse> => Request<ManagerRequest>,
});
/// Represents a client that can connect to a remote distant manager
pub struct DistantManagerClient {
auth: Box<dyn ServerRef>,
client: Client<ManagerRequest, ManagerResponse>,
distant_clients: HashMap<ConnectionId, ClientHandle>,
}
impl Drop for DistantManagerClient {
fn drop(&mut self) {
self.auth.abort();
self.client.abort();
}
}
/// Represents a raw channel between a manager client and some remote server
pub struct RawDistantChannel {
pub transport: MpscTransport<
Request<DistantMsg<DistantRequestData>>,
Response<DistantMsg<DistantResponseData>>,
>,
forward_task: JoinHandle<()>,
mailbox_task: JoinHandle<()>,
}
impl RawDistantChannel {
pub fn abort(&self) {
self.forward_task.abort();
self.mailbox_task.abort();
}
}
impl Deref for RawDistantChannel {
type Target = MpscTransport<
Request<DistantMsg<DistantRequestData>>,
Response<DistantMsg<DistantResponseData>>,
>;
fn deref(&self) -> &Self::Target {
&self.transport
}
}
impl DerefMut for RawDistantChannel {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.transport
}
}
struct ClientHandle {
client: DistantClient,
forward_task: JoinHandle<()>,
mailbox_task: JoinHandle<()>,
}
impl Drop for ClientHandle {
fn drop(&mut self) {
self.forward_task.abort();
self.mailbox_task.abort();
}
}
impl DistantManagerClient {
/// Initializes a client using the provided [`UntypedTransport`]
pub fn new<T>(config: DistantManagerClientConfig, transport: T) -> io::Result<Self>
where
T: IntoSplit + 'static,
T::Read: UntypedTransportRead + 'static,
T::Write: UntypedTransportWrite + 'static,
{
let DistantManagerClientRouter {
auth_transport,
manager_transport,
..
} = DistantManagerClientRouter::new(transport);
// Initialize our client with manager request/response transport
let (writer, reader) = manager_transport.into_split();
let client = Client::new(writer, reader)?;
// Initialize our auth handler with auth/auth transport
let auth = AuthServer {
on_challenge: config.on_challenge,
on_verify: config.on_verify,
on_info: config.on_info,
on_error: config.on_error,
}
.start(OneshotListener::from_value(auth_transport.into_split()))?;
Ok(Self {
auth,
client,
distant_clients: HashMap::new(),
})
}
/// Request that the manager launches a new server at the given `destination`
/// with `options` being passed for destination-specific details, returning the new
/// `destination` of the spawned server to connect to
pub async fn launch(
&mut self,
destination: impl Into<Destination>,
options: impl Into<Map>,
) -> io::Result<Destination> {
let destination = Box::new(destination.into());
let options = options.into();
trace!("launch({}, {})", destination, options);
let res = self
.client
.send(ManagerRequest::Launch {
destination,
options,
})
.await?;
match res.payload {
ManagerResponse::Launched { destination } => Ok(destination),
ManagerResponse::Error(x) => Err(x.into()),
x => Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("Got unexpected response: {:?}", x),
)),
}
}
/// Request that the manager establishes a new connection at the given `destination`
/// with `options` being passed for destination-specific details
pub async fn connect(
&mut self,
destination: impl Into<Destination>,
options: impl Into<Map>,
) -> io::Result<ConnectionId> {
let destination = Box::new(destination.into());
let options = options.into();
trace!("connect({}, {})", destination, options);
let res = self
.client
.send(ManagerRequest::Connect {
destination,
options,
})
.await?;
match res.payload {
ManagerResponse::Connected { id } => Ok(id),
ManagerResponse::Error(x) => Err(x.into()),
x => Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("Got unexpected response: {:?}", x),
)),
}
}
/// Establishes a channel with the server represented by the `connection_id`,
/// returning a [`DistantChannel`] acting as the connection
///
/// ### Note
///
/// Multiple calls to open a channel against the same connection will result in
/// clones of the same [`DistantChannel`] rather than establishing a duplicate
/// remote connection to the same server
pub async fn open_channel(
&mut self,
connection_id: ConnectionId,
) -> io::Result<DistantChannel> {
trace!("open_channel({})", connection_id);
if let Some(handle) = self.distant_clients.get(&connection_id) {
Ok(handle.client.clone_channel())
} else {
let RawDistantChannel {
transport,
forward_task,
mailbox_task,
} = self.open_raw_channel(connection_id).await?;
let (writer, reader) = transport.into_split();
let client = DistantClient::new(writer, reader)?;
let channel = client.clone_channel();
self.distant_clients.insert(
connection_id,
ClientHandle {
client,
forward_task,
mailbox_task,
},
);
Ok(channel)
}
}
/// Establishes a channel with the server represented by the `connection_id`,
/// returning a [`Transport`] acting as the connection
///
/// ### Note
///
/// Multiple calls to open a channel against the same connection will result in establishing a
/// duplicate remote connections to the same server, so take care when using this method
pub async fn open_raw_channel(
&mut self,
connection_id: ConnectionId,
) -> io::Result<RawDistantChannel> {
trace!("open_raw_channel({})", connection_id);
let mut mailbox = self
.client
.mail(ManagerRequest::OpenChannel { id: connection_id })
.await?;
// Wait for the first response, which should be channel confirmation
let channel_id = match mailbox.next().await {
Some(response) => match response.payload {
ManagerResponse::ChannelOpened { id } => Ok(id),
ManagerResponse::Error(x) => Err(x.into()),
x => Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("Got unexpected response: {:?}", x),
)),
},
None => Err(io::Error::new(
io::ErrorKind::ConnectionAborted,
"open_channel mailbox aborted",
)),
}?;
// Spawn reader and writer tasks to forward requests and replies
// using our opened channel
let (t1, t2) = MpscTransport::pair(1);
let (mut writer, mut reader) = t1.into_split();
let mailbox_task = tokio::spawn(async move {
use distant_net::TypedAsyncWrite;
while let Some(response) = mailbox.next().await {
match response.payload {
ManagerResponse::Channel { response, .. } => {
if let Err(x) = writer.write(response).await {
error!("[Conn {}] {}", connection_id, x);
}
}
ManagerResponse::ChannelClosed { .. } => break,
_ => continue,
}
}
});
let mut manager_channel = self.client.clone_channel();
let forward_task = tokio::spawn(async move {
use distant_net::TypedAsyncRead;
loop {
match reader.read().await {
Ok(Some(request)) => {
// NOTE: In this situation, we do not expect a response to this
// request (even if the server sends something back)
if let Err(x) = manager_channel
.fire(ManagerRequest::Channel {
id: channel_id,
request,
})
.await
{
error!("[Conn {}] {}", connection_id, x);
}
}
Ok(None) => break,
Err(x) => {
error!("[Conn {}] {}", connection_id, x);
continue;
}
}
}
});
Ok(RawDistantChannel {
transport: t2,
forward_task,
mailbox_task,
})
}
/// Retrieves a list of supported capabilities
pub async fn capabilities(&mut self) -> io::Result<ManagerCapabilities> {
trace!("capabilities()");
let res = self.client.send(ManagerRequest::Capabilities).await?;
match res.payload {
ManagerResponse::Capabilities { supported } => Ok(supported),
ManagerResponse::Error(x) => Err(x.into()),
x => Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("Got unexpected response: {:?}", x),
)),
}
}
/// Retrieves information about a specific connection
pub async fn info(&mut self, id: ConnectionId) -> io::Result<ConnectionInfo> {
trace!("info({})", id);
let res = self.client.send(ManagerRequest::Info { id }).await?;
match res.payload {
ManagerResponse::Info(info) => Ok(info),
ManagerResponse::Error(x) => Err(x.into()),
x => Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("Got unexpected response: {:?}", x),
)),
}
}
/// Kills the specified connection
pub async fn kill(&mut self, id: ConnectionId) -> io::Result<()> {
trace!("kill({})", id);
let res = self.client.send(ManagerRequest::Kill { id }).await?;
match res.payload {
ManagerResponse::Killed => Ok(()),
ManagerResponse::Error(x) => Err(x.into()),
x => Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("Got unexpected response: {:?}", x),
)),
}
}
/// Retrieves a list of active connections
pub async fn list(&mut self) -> io::Result<ConnectionList> {
trace!("list()");
let res = self.client.send(ManagerRequest::List).await?;
match res.payload {
ManagerResponse::List(list) => Ok(list),
ManagerResponse::Error(x) => Err(x.into()),
x => Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("Got unexpected response: {:?}", x),
)),
}
}
/// Requests that the manager shuts down
pub async fn shutdown(&mut self) -> io::Result<()> {
trace!("shutdown()");
let res = self.client.send(ManagerRequest::Shutdown).await?;
match res.payload {
ManagerResponse::Shutdown => Ok(()),
ManagerResponse::Error(x) => Err(x.into()),
x => Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("Got unexpected response: {:?}", x),
)),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::data::{Error, ErrorKind};
use distant_net::{
FramedTransport, InmemoryTransport, PlainCodec, UntypedTransportRead, UntypedTransportWrite,
};
fn setup() -> (
DistantManagerClient,
FramedTransport<InmemoryTransport, PlainCodec>,
) {
let (t1, t2) = FramedTransport::pair(100);
let client =
DistantManagerClient::new(DistantManagerClientConfig::with_empty_prompts(), t1)
.unwrap();
(client, t2)
}
#[inline]
fn test_error() -> Error {
Error {
kind: ErrorKind::Interrupted,
description: "test error".to_string(),
}
}
#[inline]
fn test_io_error() -> io::Error {
test_error().into()
}
#[tokio::test]
async fn connect_should_report_error_if_receives_error_response() {
let (mut client, mut transport) = setup();
tokio::spawn(async move {
let request = transport
.read::<Request<ManagerRequest>>()
.await
.unwrap()
.unwrap();
transport
.write(Response::new(
request.id,
ManagerResponse::Error(test_error()),
))
.await
.unwrap();
});
let err = client
.connect(
"scheme://host".parse::<Destination>().unwrap(),
"key=value".parse::<Map>().unwrap(),
)
.await
.unwrap_err();
assert_eq!(err.kind(), test_io_error().kind());
assert_eq!(err.to_string(), test_io_error().to_string());
}
#[tokio::test]
async fn connect_should_report_error_if_receives_unexpected_response() {
let (mut client, mut transport) = setup();
tokio::spawn(async move {
let request = transport
.read::<Request<ManagerRequest>>()
.await
.unwrap()
.unwrap();
transport
.write(Response::new(request.id, ManagerResponse::Shutdown))
.await
.unwrap();
});
let err = client
.connect(
"scheme://host".parse::<Destination>().unwrap(),
"key=value".parse::<Map>().unwrap(),
)
.await
.unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
}
#[tokio::test]
async fn connect_should_return_id_from_successful_response() {
let (mut client, mut transport) = setup();
let expected_id = 999;
tokio::spawn(async move {
let request = transport
.read::<Request<ManagerRequest>>()
.await
.unwrap()
.unwrap();
transport
.write(Response::new(
request.id,
ManagerResponse::Connected { id: expected_id },
))
.await
.unwrap();
});
let id = client
.connect(
"scheme://host".parse::<Destination>().unwrap(),
"key=value".parse::<Map>().unwrap(),
)
.await
.unwrap();
assert_eq!(id, expected_id);
}
#[tokio::test]
async fn info_should_report_error_if_receives_error_response() {
let (mut client, mut transport) = setup();
tokio::spawn(async move {
let request = transport
.read::<Request<ManagerRequest>>()
.await
.unwrap()
.unwrap();
transport
.write(Response::new(
request.id,
ManagerResponse::Error(test_error()),
))
.await
.unwrap();
});
let err = client.info(123).await.unwrap_err();
assert_eq!(err.kind(), test_io_error().kind());
assert_eq!(err.to_string(), test_io_error().to_string());
}
#[tokio::test]
async fn info_should_report_error_if_receives_unexpected_response() {
let (mut client, mut transport) = setup();
tokio::spawn(async move {
let request = transport
.read::<Request<ManagerRequest>>()
.await
.unwrap()
.unwrap();
transport
.write(Response::new(request.id, ManagerResponse::Shutdown))
.await
.unwrap();
});
let err = client.info(123).await.unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
}
#[tokio::test]
async fn info_should_return_connection_info_from_successful_response() {
let (mut client, mut transport) = setup();
tokio::spawn(async move {
let request = transport
.read::<Request<ManagerRequest>>()
.await
.unwrap()
.unwrap();
let info = ConnectionInfo {
id: 123,
destination: "scheme://host".parse::<Destination>().unwrap(),
options: "key=value".parse::<Map>().unwrap(),
};
transport
.write(Response::new(request.id, ManagerResponse::Info(info)))
.await
.unwrap();
});
let info = client.info(123).await.unwrap();
assert_eq!(info.id, 123);
assert_eq!(
info.destination,
"scheme://host".parse::<Destination>().unwrap()
);
assert_eq!(info.options, "key=value".parse::<Map>().unwrap());
}
#[tokio::test]
async fn list_should_report_error_if_receives_error_response() {
let (mut client, mut transport) = setup();
tokio::spawn(async move {
let request = transport
.read::<Request<ManagerRequest>>()
.await
.unwrap()
.unwrap();
transport
.write(Response::new(
request.id,
ManagerResponse::Error(test_error()),
))
.await
.unwrap();
});
let err = client.list().await.unwrap_err();
assert_eq!(err.kind(), test_io_error().kind());
assert_eq!(err.to_string(), test_io_error().to_string());
}
#[tokio::test]
async fn list_should_report_error_if_receives_unexpected_response() {
let (mut client, mut transport) = setup();
tokio::spawn(async move {
let request = transport
.read::<Request<ManagerRequest>>()
.await
.unwrap()
.unwrap();
transport
.write(Response::new(request.id, ManagerResponse::Shutdown))
.await
.unwrap();
});
let err = client.list().await.unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
}
#[tokio::test]
async fn list_should_return_connection_list_from_successful_response() {
let (mut client, mut transport) = setup();
tokio::spawn(async move {
let request = transport
.read::<Request<ManagerRequest>>()
.await
.unwrap()
.unwrap();
let mut list = ConnectionList::new();
list.insert(123, "scheme://host".parse::<Destination>().unwrap());
transport
.write(Response::new(request.id, ManagerResponse::List(list)))
.await
.unwrap();
});
let list = client.list().await.unwrap();
assert_eq!(list.len(), 1);
assert_eq!(
list.get(&123).expect("Connection list missing item"),
&"scheme://host".parse::<Destination>().unwrap()
);
}
#[tokio::test]
async fn kill_should_report_error_if_receives_error_response() {
let (mut client, mut transport) = setup();
tokio::spawn(async move {
let request = transport
.read::<Request<ManagerRequest>>()
.await
.unwrap()
.unwrap();
transport
.write(Response::new(
request.id,
ManagerResponse::Error(test_error()),
))
.await
.unwrap();
});
let err = client.kill(123).await.unwrap_err();
assert_eq!(err.kind(), test_io_error().kind());
assert_eq!(err.to_string(), test_io_error().to_string());
}
#[tokio::test]
async fn kill_should_report_error_if_receives_unexpected_response() {
let (mut client, mut transport) = setup();
tokio::spawn(async move {
let request = transport
.read::<Request<ManagerRequest>>()
.await
.unwrap()
.unwrap();
transport
.write(Response::new(request.id, ManagerResponse::Shutdown))
.await
.unwrap();
});
let err = client.kill(123).await.unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
}
#[tokio::test]
async fn kill_should_return_success_from_successful_response() {
let (mut client, mut transport) = setup();
tokio::spawn(async move {
let request = transport
.read::<Request<ManagerRequest>>()
.await
.unwrap()
.unwrap();
transport
.write(Response::new(request.id, ManagerResponse::Killed))
.await
.unwrap();
});
client.kill(123).await.unwrap();
}
#[tokio::test]
async fn shutdown_should_report_error_if_receives_error_response() {
let (mut client, mut transport) = setup();
tokio::spawn(async move {
let request = transport
.read::<Request<ManagerRequest>>()
.await
.unwrap()
.unwrap();
transport
.write(Response::new(
request.id,
ManagerResponse::Connected { id: 0 },
))
.await
.unwrap();
});
let err = client.shutdown().await.unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
}
#[tokio::test]
async fn shutdown_should_report_error_if_receives_unexpected_response() {
let (mut client, mut transport) = setup();
tokio::spawn(async move {
let request = transport
.read::<Request<ManagerRequest>>()
.await
.unwrap()
.unwrap();
transport
.write(Response::new(
request.id,
ManagerResponse::Error(test_error()),
))
.await
.unwrap();
});
let err = client.shutdown().await.unwrap_err();
assert_eq!(err.kind(), test_io_error().kind());
assert_eq!(err.to_string(), test_io_error().to_string());
}
#[tokio::test]
async fn shutdown_should_return_success_from_successful_response() {
let (mut client, mut transport) = setup();
tokio::spawn(async move {
let request = transport
.read::<Request<ManagerRequest>>()
.await
.unwrap()
.unwrap();
transport
.write(Response::new(request.id, ManagerResponse::Shutdown))
.await
.unwrap();
});
client.shutdown().await.unwrap();
}
}

@ -1,85 +0,0 @@
use distant_net::{AuthChallengeFn, AuthErrorFn, AuthInfoFn, AuthVerifyFn, AuthVerifyKind};
use log::*;
use std::io;
/// Configuration to use when creating a new [`DistantManagerClient`](super::DistantManagerClient)
pub struct DistantManagerClientConfig {
pub on_challenge: Box<AuthChallengeFn>,
pub on_verify: Box<AuthVerifyFn>,
pub on_info: Box<AuthInfoFn>,
pub on_error: Box<AuthErrorFn>,
}
impl DistantManagerClientConfig {
/// Creates a new config with prompts that return empty strings
pub fn with_empty_prompts() -> Self {
Self::with_prompts(|_| Ok("".to_string()), |_| Ok("".to_string()))
}
/// Creates a new config with two prompts
///
/// * `password_prompt` - used for prompting for a secret, and should not display what is typed
/// * `text_prompt` - used for general text, and is okay to display what is typed
pub fn with_prompts<PP, PT>(password_prompt: PP, text_prompt: PT) -> Self
where
PP: Fn(&str) -> io::Result<String> + Send + Sync + 'static,
PT: Fn(&str) -> io::Result<String> + Send + Sync + 'static,
{
Self {
on_challenge: Box::new(move |questions, _extra| {
trace!("[manager client] on_challenge({questions:?}, {_extra:?})");
let mut answers = Vec::new();
for question in questions.iter() {
// Contains all prompt lines including same line
let mut lines = question.text.split('\n').collect::<Vec<_>>();
// Line that is prompt on same line as answer
let line = lines.pop().unwrap();
// Go ahead and display all other lines
for line in lines.into_iter() {
eprintln!("{}", line);
}
// Get an answer from user input, or use a blank string as an answer
// if we fail to get input from the user
let answer = password_prompt(line).unwrap_or_default();
answers.push(answer);
}
answers
}),
on_verify: Box::new(move |kind, text| {
trace!("[manager client] on_verify({kind}, {text})");
match kind {
AuthVerifyKind::Host => {
eprintln!("{}", text);
match text_prompt("Enter [y/N]> ") {
Ok(answer) => {
trace!("Verify? Answer = '{answer}'");
matches!(answer.trim(), "y" | "Y" | "yes" | "YES")
}
Err(x) => {
error!("Failed verification: {x}");
false
}
}
}
x => {
error!("Unsupported verify kind: {x}");
false
}
}
}),
on_info: Box::new(|text| {
trace!("[manager client] on_info({text})");
println!("{}", text);
}),
on_error: Box::new(|kind, text| {
trace!("[manager client] on_error({kind}, {text})");
eprintln!("{}: {}", kind, text);
}),
}
}
}

@ -1,14 +0,0 @@
mod tcp;
pub use tcp::*;
#[cfg(unix)]
mod unix;
#[cfg(unix)]
pub use unix::*;
#[cfg(windows)]
mod windows;
#[cfg(windows)]
pub use windows::*;

@ -1,50 +0,0 @@
use crate::{DistantManagerClient, DistantManagerClientConfig};
use async_trait::async_trait;
use distant_net::{Codec, FramedTransport, TcpTransport};
use std::{convert, net::SocketAddr};
use tokio::{io, time::Duration};
#[async_trait]
pub trait TcpDistantManagerClientExt {
/// Connect to a remote TCP server using the provided information
async fn connect<C>(
config: DistantManagerClientConfig,
addr: SocketAddr,
codec: C,
) -> io::Result<DistantManagerClient>
where
C: Codec + Send + 'static;
/// Connect to a remote TCP server, timing out after duration has passed
async fn connect_timeout<C>(
config: DistantManagerClientConfig,
addr: SocketAddr,
codec: C,
duration: Duration,
) -> io::Result<DistantManagerClient>
where
C: Codec + Send + 'static,
{
tokio::time::timeout(duration, Self::connect(config, addr, codec))
.await
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
.and_then(convert::identity)
}
}
#[async_trait]
impl TcpDistantManagerClientExt for DistantManagerClient {
/// Connect to a remote TCP server using the provided information
async fn connect<C>(
config: DistantManagerClientConfig,
addr: SocketAddr,
codec: C,
) -> io::Result<DistantManagerClient>
where
C: Codec + Send + 'static,
{
let transport = TcpTransport::connect(addr).await?;
let transport = FramedTransport::new(transport, codec);
Self::new(config, transport)
}
}

@ -1,54 +0,0 @@
use crate::{DistantManagerClient, DistantManagerClientConfig};
use async_trait::async_trait;
use distant_net::{Codec, FramedTransport, UnixSocketTransport};
use std::{convert, path::Path};
use tokio::{io, time::Duration};
#[async_trait]
pub trait UnixSocketDistantManagerClientExt {
/// Connect to a proxy unix socket
async fn connect<P, C>(
config: DistantManagerClientConfig,
path: P,
codec: C,
) -> io::Result<DistantManagerClient>
where
P: AsRef<Path> + Send,
C: Codec + Send + 'static;
/// Connect to a proxy unix socket, timing out after duration has passed
async fn connect_timeout<P, C>(
config: DistantManagerClientConfig,
path: P,
codec: C,
duration: Duration,
) -> io::Result<DistantManagerClient>
where
P: AsRef<Path> + Send,
C: Codec + Send + 'static,
{
tokio::time::timeout(duration, Self::connect(config, path, codec))
.await
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
.and_then(convert::identity)
}
}
#[async_trait]
impl UnixSocketDistantManagerClientExt for DistantManagerClient {
/// Connect to a proxy unix socket
async fn connect<P, C>(
config: DistantManagerClientConfig,
path: P,
codec: C,
) -> io::Result<DistantManagerClient>
where
P: AsRef<Path> + Send,
C: Codec + Send + 'static,
{
let p = path.as_ref();
let transport = UnixSocketTransport::connect(p).await?;
let transport = FramedTransport::new(transport, codec);
Ok(DistantManagerClient::new(config, transport)?)
}
}

@ -1,91 +0,0 @@
use crate::{DistantManagerClient, DistantManagerClientConfig};
use async_trait::async_trait;
use distant_net::{Codec, FramedTransport, WindowsPipeTransport};
use std::{
convert,
ffi::{OsStr, OsString},
};
use tokio::{io, time::Duration};
#[async_trait]
pub trait WindowsPipeDistantManagerClientExt {
/// Connect to a server listening on a Windows pipe at the specified address
/// using the given codec
async fn connect<A, C>(
config: DistantManagerClientConfig,
addr: A,
codec: C,
) -> io::Result<DistantManagerClient>
where
A: AsRef<OsStr> + Send,
C: Codec + Send + 'static;
/// Connect to a server listening on a Windows pipe at the specified address
/// via `\\.\pipe\{name}` using the given codec
async fn connect_local<N, C>(
config: DistantManagerClientConfig,
name: N,
codec: C,
) -> io::Result<DistantManagerClient>
where
N: AsRef<OsStr> + Send,
C: Codec + Send + 'static,
{
let mut addr = OsString::from(r"\\.\pipe\");
addr.push(name.as_ref());
Self::connect(config, addr, codec).await
}
/// Connect to a server listening on a Windows pipe at the specified address
/// using the given codec, timing out after duration has passed
async fn connect_timeout<A, C>(
config: DistantManagerClientConfig,
addr: A,
codec: C,
duration: Duration,
) -> io::Result<DistantManagerClient>
where
A: AsRef<OsStr> + Send,
C: Codec + Send + 'static,
{
tokio::time::timeout(duration, Self::connect(config, addr, codec))
.await
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
.and_then(convert::identity)
}
/// Connect to a server listening on a Windows pipe at the specified address
/// via `\\.\pipe\{name}` using the given codec, timing out after duration has passed
async fn connect_local_timeout<N, C>(
config: DistantManagerClientConfig,
name: N,
codec: C,
duration: Duration,
) -> io::Result<DistantManagerClient>
where
N: AsRef<OsStr> + Send,
C: Codec + Send + 'static,
{
let mut addr = OsString::from(r"\\.\pipe\");
addr.push(name.as_ref());
Self::connect_timeout(config, addr, codec, duration).await
}
}
#[async_trait]
impl WindowsPipeDistantManagerClientExt for DistantManagerClient {
async fn connect<A, C>(
config: DistantManagerClientConfig,
addr: A,
codec: C,
) -> io::Result<DistantManagerClient>
where
A: AsRef<OsStr> + Send,
C: Codec + Send + 'static,
{
let a = addr.as_ref();
let transport = WindowsPipeTransport::connect(a).await?;
let transport = FramedTransport::new(transport, codec);
Ok(DistantManagerClient::new(config, transport)?)
}
}

@ -1,5 +0,0 @@
/// Id associated with an active connection
pub type ConnectionId = u64;
/// Id associated with an open channel
pub type ChannelId = u64;

@ -1,719 +0,0 @@
use crate::{
ChannelId, ConnectionId, ConnectionInfo, ConnectionList, Destination, ManagerCapabilities,
ManagerRequest, ManagerResponse, Map,
};
use async_trait::async_trait;
use distant_net::{
router, Auth, AuthClient, Client, IntoSplit, Listener, MpscListener, Request, Response, Server,
ServerCtx, ServerExt, UntypedTransportRead, UntypedTransportWrite,
};
use log::*;
use std::{collections::HashMap, io, sync::Arc};
use tokio::{
sync::{mpsc, Mutex, RwLock},
task::JoinHandle,
};
mod config;
pub use config::*;
mod connection;
pub use connection::*;
mod ext;
pub use ext::*;
mod handler;
pub use handler::*;
mod r#ref;
pub use r#ref::*;
router!(DistantManagerRouter {
auth_transport: Response<Auth> => Request<Auth>,
manager_transport: Request<ManagerRequest> => Response<ManagerResponse>,
});
/// Represents a manager of multiple distant server connections
pub struct DistantManager {
/// Receives authentication clients to feed into local data of server
auth_client_rx: Mutex<mpsc::Receiver<AuthClient>>,
/// Configuration settings for the server
config: DistantManagerConfig,
/// Mapping of connection id -> connection
connections: RwLock<HashMap<ConnectionId, DistantManagerConnection>>,
/// Handlers for launch requests
launch_handlers: Arc<RwLock<HashMap<String, BoxedLaunchHandler>>>,
/// Handlers for connect requests
connect_handlers: Arc<RwLock<HashMap<String, BoxedConnectHandler>>>,
/// Primary task of server
task: JoinHandle<()>,
}
impl DistantManager {
/// Initializes a new instance of [`DistantManagerServer`] using the provided [`UntypedTransport`]
pub fn start<L, T>(
mut config: DistantManagerConfig,
mut listener: L,
) -> io::Result<DistantManagerRef>
where
L: Listener<Output = T> + 'static,
T: IntoSplit + Send + 'static,
T::Read: UntypedTransportRead + 'static,
T::Write: UntypedTransportWrite + 'static,
{
let (conn_tx, mpsc_listener) = MpscListener::channel(config.connection_buffer_size);
let (auth_client_tx, auth_client_rx) = mpsc::channel(1);
// Spawn task that uses our input listener to get both auth and manager events,
// forwarding manager events to the internal mpsc listener
let task = tokio::spawn(async move {
while let Ok(transport) = listener.accept().await {
let DistantManagerRouter {
auth_transport,
manager_transport,
..
} = DistantManagerRouter::new(transport);
let (writer, reader) = auth_transport.into_split();
let client = match Client::new(writer, reader) {
Ok(client) => client,
Err(x) => {
error!("Creating auth client failed: {}", x);
continue;
}
};
let auth_client = AuthClient::from(client);
// Forward auth client for new connection in server
if auth_client_tx.send(auth_client).await.is_err() {
break;
}
// Forward connected and routed transport to server
if conn_tx.send(manager_transport.into_split()).await.is_err() {
break;
}
}
});
let launch_handlers = Arc::new(RwLock::new(config.launch_handlers.drain().collect()));
let weak_launch_handlers = Arc::downgrade(&launch_handlers);
let connect_handlers = Arc::new(RwLock::new(config.connect_handlers.drain().collect()));
let weak_connect_handlers = Arc::downgrade(&connect_handlers);
let server_ref = Self {
auth_client_rx: Mutex::new(auth_client_rx),
config,
launch_handlers,
connect_handlers,
connections: RwLock::new(HashMap::new()),
task,
}
.start(mpsc_listener)?;
Ok(DistantManagerRef {
launch_handlers: weak_launch_handlers,
connect_handlers: weak_connect_handlers,
inner: server_ref,
})
}
/// Launches a new server at the specified `destination` using the given `options` information
/// and authentication client (if needed) to retrieve additional information needed to
/// enter the destination prior to starting the server, returning the destination of the
/// launched server
async fn launch(
&self,
destination: Destination,
options: Map,
auth: Option<&mut AuthClient>,
) -> io::Result<Destination> {
let auth = auth.ok_or_else(|| {
io::Error::new(
io::ErrorKind::Other,
"Authentication client not initialized",
)
})?;
let scheme = match destination.scheme.as_deref() {
Some(scheme) => {
trace!("Using scheme {}", scheme);
scheme
}
None => {
trace!(
"Using fallback scheme of {}",
self.config.launch_fallback_scheme.as_str()
);
self.config.launch_fallback_scheme.as_str()
}
}
.to_lowercase();
let credentials = {
let lock = self.launch_handlers.read().await;
let handler = lock.get(&scheme).ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("No launch handler registered for {}", scheme),
)
})?;
handler.launch(&destination, &options, auth).await?
};
Ok(credentials)
}
/// Connects to a new server at the specified `destination` using the given `options` information
/// and authentication client (if needed) to retrieve additional information needed to
/// establish the connection to the server
async fn connect(
&self,
destination: Destination,
options: Map,
auth: Option<&mut AuthClient>,
) -> io::Result<ConnectionId> {
let auth = auth.ok_or_else(|| {
io::Error::new(
io::ErrorKind::Other,
"Authentication client not initialized",
)
})?;
let scheme = match destination.scheme.as_deref() {
Some(scheme) => {
trace!("Using scheme {}", scheme);
scheme
}
None => {
trace!(
"Using fallback scheme of {}",
self.config.connect_fallback_scheme.as_str()
);
self.config.connect_fallback_scheme.as_str()
}
}
.to_lowercase();
let (writer, reader) = {
let lock = self.connect_handlers.read().await;
let handler = lock.get(&scheme).ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("No connect handler registered for {}", scheme),
)
})?;
handler.connect(&destination, &options, auth).await?
};
let connection = DistantManagerConnection::new(destination, options, writer, reader);
let id = connection.id;
self.connections.write().await.insert(id, connection);
Ok(id)
}
/// Retrieves the list of supported capabilities for this manager
async fn capabilities(&self) -> io::Result<ManagerCapabilities> {
Ok(ManagerCapabilities::all())
}
/// Retrieves information about the connection to the server with the specified `id`
async fn info(&self, id: ConnectionId) -> io::Result<ConnectionInfo> {
match self.connections.read().await.get(&id) {
Some(connection) => Ok(ConnectionInfo {
id: connection.id,
destination: connection.destination.clone(),
options: connection.options.clone(),
}),
None => Err(io::Error::new(
io::ErrorKind::NotConnected,
"No connection found",
)),
}
}
/// Retrieves a list of connections to servers
async fn list(&self) -> io::Result<ConnectionList> {
Ok(ConnectionList(
self.connections
.read()
.await
.values()
.map(|conn| (conn.id, conn.destination.clone()))
.collect(),
))
}
/// Kills the connection to the server with the specified `id`
async fn kill(&self, id: ConnectionId) -> io::Result<()> {
match self.connections.write().await.remove(&id) {
Some(_) => Ok(()),
None => Err(io::Error::new(
io::ErrorKind::NotConnected,
"No connection found",
)),
}
}
}
#[derive(Default)]
pub struct DistantManagerServerConnection {
/// Authentication client that manager can use when establishing a new connection
/// and needing to get authentication details from the client to move forward
auth_client: Option<Mutex<AuthClient>>,
/// Holds on to open channels feeding data back from a server to some connected client,
/// enabling us to cancel the tasks on demand
channels: RwLock<HashMap<ChannelId, DistantManagerChannel>>,
}
#[async_trait]
impl Server for DistantManager {
type Request = ManagerRequest;
type Response = ManagerResponse;
type LocalData = DistantManagerServerConnection;
async fn on_accept(&self, local_data: &mut Self::LocalData) {
local_data.auth_client = self
.auth_client_rx
.lock()
.await
.recv()
.await
.map(Mutex::new);
// Enable jit handshake
if let Some(auth_client) = local_data.auth_client.as_ref() {
auth_client.lock().await.set_jit_handshake(true);
}
}
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) {
let ServerCtx {
connection_id,
request,
reply,
local_data,
} = ctx;
let response = match request.payload {
ManagerRequest::Capabilities {} => match self.capabilities().await {
Ok(supported) => ManagerResponse::Capabilities { supported },
Err(x) => ManagerResponse::Error(x.into()),
},
ManagerRequest::Launch {
destination,
options,
} => {
let mut auth = match local_data.auth_client.as_ref() {
Some(client) => Some(client.lock().await),
None => None,
};
match self
.launch(*destination, options, auth.as_deref_mut())
.await
{
Ok(destination) => ManagerResponse::Launched { destination },
Err(x) => ManagerResponse::Error(x.into()),
}
}
ManagerRequest::Connect {
destination,
options,
} => {
let mut auth = match local_data.auth_client.as_ref() {
Some(client) => Some(client.lock().await),
None => None,
};
match self
.connect(*destination, options, auth.as_deref_mut())
.await
{
Ok(id) => ManagerResponse::Connected { id },
Err(x) => ManagerResponse::Error(x.into()),
}
}
ManagerRequest::OpenChannel { id } => match self.connections.read().await.get(&id) {
Some(connection) => match connection.open_channel(reply.clone()).await {
Ok(channel) => {
let id = channel.id();
local_data.channels.write().await.insert(id, channel);
ManagerResponse::ChannelOpened { id }
}
Err(x) => ManagerResponse::Error(x.into()),
},
None => ManagerResponse::Error(
io::Error::new(io::ErrorKind::NotConnected, "Connection does not exist").into(),
),
},
ManagerRequest::Channel { id, request } => {
match local_data.channels.read().await.get(&id) {
// TODO: For now, we are NOT sending back a response to acknowledge
// a successful channel send. We could do this in order for
// the client to listen for a complete send, but is it worth it?
Some(channel) => match channel.send(request).await {
Ok(_) => return,
Err(x) => ManagerResponse::Error(x.into()),
},
None => ManagerResponse::Error(
io::Error::new(
io::ErrorKind::NotConnected,
"Channel is not open or does not exist",
)
.into(),
),
}
}
ManagerRequest::CloseChannel { id } => {
match local_data.channels.write().await.remove(&id) {
Some(channel) => match channel.close().await {
Ok(_) => ManagerResponse::ChannelClosed { id },
Err(x) => ManagerResponse::Error(x.into()),
},
None => ManagerResponse::Error(
io::Error::new(
io::ErrorKind::NotConnected,
"Channel is not open or does not exist",
)
.into(),
),
}
}
ManagerRequest::Info { id } => match self.info(id).await {
Ok(info) => ManagerResponse::Info(info),
Err(x) => ManagerResponse::Error(x.into()),
},
ManagerRequest::List => match self.list().await {
Ok(list) => ManagerResponse::List(list),
Err(x) => ManagerResponse::Error(x.into()),
},
ManagerRequest::Kill { id } => match self.kill(id).await {
Ok(()) => ManagerResponse::Killed,
Err(x) => ManagerResponse::Error(x.into()),
},
ManagerRequest::Shutdown => {
if let Err(x) = reply.send(ManagerResponse::Shutdown).await {
error!("[Conn {}] {}", connection_id, x);
}
// Clear out handler state in order to trigger drops
self.launch_handlers.write().await.clear();
self.connect_handlers.write().await.clear();
// Shutdown the primary server task
self.task.abort();
// TODO: Perform a graceful shutdown instead of this?
// Review https://tokio.rs/tokio/topics/shutdown
std::process::exit(0);
}
};
if let Err(x) = reply.send(response).await {
error!("[Conn {}] {}", connection_id, x);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use distant_net::{
AuthClient, FramedTransport, HeapAuthServer, InmemoryTransport, IntoSplit, MappedListener,
OneshotListener, PlainCodec, ServerExt, ServerRef,
};
/// Create a new server, bypassing the start loop
fn setup() -> DistantManager {
let (_, rx) = mpsc::channel(1);
DistantManager {
auth_client_rx: Mutex::new(rx),
config: Default::default(),
connections: RwLock::new(HashMap::new()),
launch_handlers: Arc::new(RwLock::new(HashMap::new())),
connect_handlers: Arc::new(RwLock::new(HashMap::new())),
task: tokio::spawn(async move {}),
}
}
/// Creates a connected [`AuthClient`] with a launched auth server that blindly responds
fn auth_client_server() -> (AuthClient, Box<dyn ServerRef>) {
let (t1, t2) = FramedTransport::pair(1);
let client = AuthClient::from(Client::from_framed_transport(t1).unwrap());
// Create a server that does nothing, but will support
let server = HeapAuthServer {
on_challenge: Box::new(|_, _| Vec::new()),
on_verify: Box::new(|_, _| false),
on_info: Box::new(|_| ()),
on_error: Box::new(|_, _| ()),
}
.start(MappedListener::new(OneshotListener::from_value(t2), |t| {
t.into_split()
}))
.unwrap();
(client, server)
}
fn dummy_distant_writer_reader() -> (BoxedDistantWriter, BoxedDistantReader) {
setup_distant_writer_reader().0
}
/// Creates a writer & reader with a connected transport
fn setup_distant_writer_reader() -> (
(BoxedDistantWriter, BoxedDistantReader),
FramedTransport<InmemoryTransport, PlainCodec>,
) {
let (t1, t2) = FramedTransport::pair(1);
let (writer, reader) = t1.into_split();
((Box::new(writer), Box::new(reader)), t2)
}
#[tokio::test]
async fn launch_should_fail_if_destination_scheme_is_unsupported() {
let server = setup();
let destination = "scheme://host".parse::<Destination>().unwrap();
let options = "".parse::<Map>().unwrap();
let (mut auth, _auth_server) = auth_client_server();
let err = server
.launch(destination, options, Some(&mut auth))
.await
.unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::InvalidInput, "{:?}", err);
}
#[tokio::test]
async fn launch_should_fail_if_handler_tied_to_scheme_fails() {
let server = setup();
let handler: Box<dyn LaunchHandler> = Box::new(|_: &_, _: &_, _: &mut _| async {
Err(io::Error::new(io::ErrorKind::Other, "test failure"))
});
server
.launch_handlers
.write()
.await
.insert("scheme".to_string(), handler);
let destination = "scheme://host".parse::<Destination>().unwrap();
let options = "".parse::<Map>().unwrap();
let (mut auth, _auth_server) = auth_client_server();
let err = server
.launch(destination, options, Some(&mut auth))
.await
.unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::Other);
assert_eq!(err.to_string(), "test failure");
}
#[tokio::test]
async fn launch_should_return_new_destination_on_success() {
let server = setup();
let handler: Box<dyn LaunchHandler> = {
Box::new(|_: &_, _: &_, _: &mut _| async {
Ok("scheme2://host2".parse::<Destination>().unwrap())
})
};
server
.launch_handlers
.write()
.await
.insert("scheme".to_string(), handler);
let destination = "scheme://host".parse::<Destination>().unwrap();
let options = "key=value".parse::<Map>().unwrap();
let (mut auth, _auth_server) = auth_client_server();
let destination = server
.launch(destination, options, Some(&mut auth))
.await
.unwrap();
assert_eq!(
destination,
"scheme2://host2".parse::<Destination>().unwrap()
);
}
#[tokio::test]
async fn connect_should_fail_if_destination_scheme_is_unsupported() {
let server = setup();
let destination = "scheme://host".parse::<Destination>().unwrap();
let options = "".parse::<Map>().unwrap();
let (mut auth, _auth_server) = auth_client_server();
let err = server
.connect(destination, options, Some(&mut auth))
.await
.unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::InvalidInput, "{:?}", err);
}
#[tokio::test]
async fn connect_should_fail_if_handler_tied_to_scheme_fails() {
let server = setup();
let handler: Box<dyn ConnectHandler> = Box::new(|_: &_, _: &_, _: &mut _| async {
Err(io::Error::new(io::ErrorKind::Other, "test failure"))
});
server
.connect_handlers
.write()
.await
.insert("scheme".to_string(), handler);
let destination = "scheme://host".parse::<Destination>().unwrap();
let options = "".parse::<Map>().unwrap();
let (mut auth, _auth_server) = auth_client_server();
let err = server
.connect(destination, options, Some(&mut auth))
.await
.unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::Other);
assert_eq!(err.to_string(), "test failure");
}
#[tokio::test]
async fn connect_should_return_id_of_new_connection_on_success() {
let server = setup();
let handler: Box<dyn ConnectHandler> =
Box::new(|_: &_, _: &_, _: &mut _| async { Ok(dummy_distant_writer_reader()) });
server
.connect_handlers
.write()
.await
.insert("scheme".to_string(), handler);
let destination = "scheme://host".parse::<Destination>().unwrap();
let options = "key=value".parse::<Map>().unwrap();
let (mut auth, _auth_server) = auth_client_server();
let id = server
.connect(destination, options, Some(&mut auth))
.await
.unwrap();
let lock = server.connections.read().await;
let connection = lock.get(&id).unwrap();
assert_eq!(connection.id, id);
assert_eq!(connection.destination, "scheme://host");
assert_eq!(connection.options, "key=value".parse().unwrap());
}
#[tokio::test]
async fn info_should_fail_if_no_connection_found_for_specified_id() {
let server = setup();
let err = server.info(999).await.unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::NotConnected, "{:?}", err);
}
#[tokio::test]
async fn info_should_return_information_about_established_connection() {
let server = setup();
let (writer, reader) = dummy_distant_writer_reader();
let connection = DistantManagerConnection::new(
"scheme://host".parse().unwrap(),
"key=value".parse().unwrap(),
writer,
reader,
);
let id = connection.id;
server.connections.write().await.insert(id, connection);
let info = server.info(id).await.unwrap();
assert_eq!(
info,
ConnectionInfo {
id,
destination: "scheme://host".parse().unwrap(),
options: "key=value".parse().unwrap(),
}
);
}
#[tokio::test]
async fn list_should_return_empty_connection_list_if_no_established_connections() {
let server = setup();
let list = server.list().await.unwrap();
assert_eq!(list, ConnectionList(HashMap::new()));
}
#[tokio::test]
async fn list_should_return_a_list_of_established_connections() {
let server = setup();
let (writer, reader) = dummy_distant_writer_reader();
let connection = DistantManagerConnection::new(
"scheme://host".parse().unwrap(),
"key=value".parse().unwrap(),
writer,
reader,
);
let id_1 = connection.id;
server.connections.write().await.insert(id_1, connection);
let (writer, reader) = dummy_distant_writer_reader();
let connection = DistantManagerConnection::new(
"other://host2".parse().unwrap(),
"key=value".parse().unwrap(),
writer,
reader,
);
let id_2 = connection.id;
server.connections.write().await.insert(id_2, connection);
let list = server.list().await.unwrap();
assert_eq!(
list.get(&id_1).unwrap(),
&"scheme://host".parse::<Destination>().unwrap()
);
assert_eq!(
list.get(&id_2).unwrap(),
&"other://host2".parse::<Destination>().unwrap()
);
}
#[tokio::test]
async fn kill_should_fail_if_no_connection_found_for_specified_id() {
let server = setup();
let err = server.kill(999).await.unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::NotConnected, "{:?}", err);
}
#[tokio::test]
async fn kill_should_terminate_established_connection_and_remove_it_from_the_list() {
let server = setup();
let (writer, reader) = dummy_distant_writer_reader();
let connection = DistantManagerConnection::new(
"scheme://host".parse().unwrap(),
"key=value".parse().unwrap(),
writer,
reader,
);
let id = connection.id;
server.connections.write().await.insert(id, connection);
server.kill(id).await.unwrap();
let lock = server.connections.read().await;
assert!(!lock.contains_key(&id), "Connection still exists");
}
}

@ -1,202 +0,0 @@
use crate::{
data::Map,
manager::{
data::{ChannelId, ConnectionId, Destination},
BoxedDistantReader, BoxedDistantWriter,
},
DistantMsg, DistantRequestData, DistantResponseData, ManagerResponse,
};
use distant_net::{Request, Response, ServerReply};
use log::*;
use std::{collections::HashMap, io};
use tokio::{sync::mpsc, task::JoinHandle};
/// Represents a connection a distant manager has with some distant-compatible server
pub struct DistantManagerConnection {
pub id: ConnectionId,
pub destination: Destination,
pub options: Map,
tx: mpsc::Sender<StateMachine>,
reader_task: JoinHandle<()>,
writer_task: JoinHandle<()>,
}
#[derive(Clone)]
pub struct DistantManagerChannel {
channel_id: ChannelId,
tx: mpsc::Sender<StateMachine>,
}
impl DistantManagerChannel {
pub fn id(&self) -> ChannelId {
self.channel_id
}
pub async fn send(&self, request: Request<DistantMsg<DistantRequestData>>) -> io::Result<()> {
let channel_id = self.channel_id;
self.tx
.send(StateMachine::Write {
id: channel_id,
request,
})
.await
.map_err(|x| {
io::Error::new(
io::ErrorKind::BrokenPipe,
format!("channel {} send failed: {}", channel_id, x),
)
})
}
pub async fn close(&self) -> io::Result<()> {
let channel_id = self.channel_id;
self.tx
.send(StateMachine::Unregister { id: channel_id })
.await
.map_err(|x| {
io::Error::new(
io::ErrorKind::BrokenPipe,
format!("channel {} close failed: {}", channel_id, x),
)
})
}
}
enum StateMachine {
Register {
id: ChannelId,
reply: ServerReply<ManagerResponse>,
},
Unregister {
id: ChannelId,
},
Read {
response: Response<DistantMsg<DistantResponseData>>,
},
Write {
id: ChannelId,
request: Request<DistantMsg<DistantRequestData>>,
},
}
impl DistantManagerConnection {
pub fn new(
destination: Destination,
options: Map,
mut writer: BoxedDistantWriter,
mut reader: BoxedDistantReader,
) -> Self {
let connection_id = rand::random();
let (tx, mut rx) = mpsc::channel(1);
let reader_task = {
let tx = tx.clone();
tokio::spawn(async move {
loop {
match reader.read().await {
Ok(Some(response)) => {
if tx.send(StateMachine::Read { response }).await.is_err() {
break;
}
}
Ok(None) => break,
Err(x) => {
error!("[Conn {}] {}", connection_id, x);
continue;
}
}
}
})
};
let writer_task = tokio::spawn(async move {
let mut registered = HashMap::new();
while let Some(state_machine) = rx.recv().await {
match state_machine {
StateMachine::Register { id, reply } => {
registered.insert(id, reply);
}
StateMachine::Unregister { id } => {
registered.remove(&id);
}
StateMachine::Read { mut response } => {
// Split {channel id}_{request id} back into pieces and
// update the origin id to match the request id only
let channel_id = match response.origin_id.split_once('_') {
Some((cid_str, oid_str)) => {
if let Ok(cid) = cid_str.parse::<ChannelId>() {
response.origin_id = oid_str.to_string();
cid
} else {
continue;
}
}
None => continue,
};
if let Some(reply) = registered.get(&channel_id) {
let response = ManagerResponse::Channel {
id: channel_id,
response,
};
if let Err(x) = reply.send(response).await {
error!("[Conn {}] {}", connection_id, x);
}
}
}
StateMachine::Write { id, request } => {
// Combine channel id with request id so we can properly forward
// the response containing this in the origin id
let request = Request {
id: format!("{}_{}", id, request.id),
payload: request.payload,
};
if let Err(x) = writer.write(request).await {
error!("[Conn {}] {}", connection_id, x);
}
}
}
}
});
Self {
id: connection_id,
destination,
options,
tx,
reader_task,
writer_task,
}
}
pub async fn open_channel(
&self,
reply: ServerReply<ManagerResponse>,
) -> io::Result<DistantManagerChannel> {
let channel_id = rand::random();
self.tx
.send(StateMachine::Register {
id: channel_id,
reply,
})
.await
.map_err(|x| {
io::Error::new(
io::ErrorKind::BrokenPipe,
format!("open_channel failed: {}", x),
)
})?;
Ok(DistantManagerChannel {
channel_id,
tx: self.tx.clone(),
})
}
}
impl Drop for DistantManagerConnection {
fn drop(&mut self) {
self.reader_task.abort();
self.writer_task.abort();
}
}

@ -1,14 +0,0 @@
mod tcp;
pub use tcp::*;
#[cfg(unix)]
mod unix;
#[cfg(unix)]
pub use unix::*;
#[cfg(windows)]
mod windows;
#[cfg(windows)]
pub use windows::*;

@ -1,30 +0,0 @@
use crate::{DistantManager, DistantManagerConfig};
use distant_net::{
Codec, FramedTransport, IntoSplit, MappedListener, PortRange, TcpListener, TcpServerRef,
};
use std::{io, net::IpAddr};
impl DistantManager {
/// Start a new server by binding to the given IP address and one of the ports in the
/// specified range, mapping all connections to use the given codec
pub async fn start_tcp<P, C>(
config: DistantManagerConfig,
addr: IpAddr,
port: P,
codec: C,
) -> io::Result<TcpServerRef>
where
P: Into<PortRange> + Send,
C: Codec + Send + Sync + 'static,
{
let listener = TcpListener::bind(addr, port).await?;
let port = listener.port();
let listener = MappedListener::new(listener, move |transport| {
let transport = FramedTransport::new(transport, codec.clone());
transport.into_split()
});
let inner = DistantManager::start(config, listener)?;
Ok(TcpServerRef::new(addr, port, Box::new(inner)))
}
}

@ -1,50 +0,0 @@
use crate::{DistantManager, DistantManagerConfig};
use distant_net::{
Codec, FramedTransport, IntoSplit, MappedListener, UnixSocketListener, UnixSocketServerRef,
};
use std::{io, path::Path};
impl DistantManager {
/// Start a new server using the specified path as a unix socket using default unix socket file
/// permissions
pub async fn start_unix_socket<P, C>(
config: DistantManagerConfig,
path: P,
codec: C,
) -> io::Result<UnixSocketServerRef>
where
P: AsRef<Path> + Send,
C: Codec + Send + Sync + 'static,
{
Self::start_unix_socket_with_permissions(
config,
path,
codec,
UnixSocketListener::default_unix_socket_file_permissions(),
)
.await
}
/// Start a new server using the specified path as a unix socket and `mode` as the unix socket
/// file permissions
pub async fn start_unix_socket_with_permissions<P, C>(
config: DistantManagerConfig,
path: P,
codec: C,
mode: u32,
) -> io::Result<UnixSocketServerRef>
where
P: AsRef<Path> + Send,
C: Codec + Send + Sync + 'static,
{
let listener = UnixSocketListener::bind_with_permissions(path, mode).await?;
let path = listener.path().to_path_buf();
let listener = MappedListener::new(listener, move |transport| {
let transport = FramedTransport::new(transport, codec.clone());
transport.into_split()
});
let inner = DistantManager::start(config, listener)?;
Ok(UnixSocketServerRef::new(path, Box::new(inner)))
}
}

@ -1,48 +0,0 @@
use crate::{DistantManager, DistantManagerConfig};
use distant_net::{
Codec, FramedTransport, IntoSplit, MappedListener, WindowsPipeListener, WindowsPipeServerRef,
};
use std::{
ffi::{OsStr, OsString},
io,
};
impl DistantManager {
/// Start a new server at the specified address via `\\.\pipe\{name}` using the given codec
pub async fn start_local_named_pipe<N, C>(
config: DistantManagerConfig,
name: N,
codec: C,
) -> io::Result<WindowsPipeServerRef>
where
Self: Sized,
N: AsRef<OsStr> + Send,
C: Codec + Send + Sync + 'static,
{
let mut addr = OsString::from(r"\\.\pipe\");
addr.push(name.as_ref());
Self::start_named_pipe(config, addr, codec).await
}
/// Start a new server at the specified pipe address using the given codec
pub async fn start_named_pipe<A, C>(
config: DistantManagerConfig,
addr: A,
codec: C,
) -> io::Result<WindowsPipeServerRef>
where
A: AsRef<OsStr> + Send,
C: Codec + Send + Sync + 'static,
{
let a = addr.as_ref();
let listener = WindowsPipeListener::bind(a)?;
let addr = listener.addr().to_os_string();
let listener = MappedListener::new(listener, move |transport| {
let transport = FramedTransport::new(transport, codec.clone());
transport.into_split()
});
let inner = DistantManager::start(config, listener)?;
Ok(WindowsPipeServerRef::new(addr, Box::new(inner)))
}
}

@ -1,68 +0,0 @@
use crate::{
data::Map, manager::data::Destination, DistantMsg, DistantRequestData, DistantResponseData,
};
use async_trait::async_trait;
use distant_net::{AuthClient, Request, Response, TypedAsyncRead, TypedAsyncWrite};
use std::{future::Future, io};
pub type BoxedDistantWriter =
Box<dyn TypedAsyncWrite<Request<DistantMsg<DistantRequestData>>> + Send>;
pub type BoxedDistantReader =
Box<dyn TypedAsyncRead<Response<DistantMsg<DistantResponseData>>> + Send>;
pub type BoxedDistantWriterReader = (BoxedDistantWriter, BoxedDistantReader);
pub type BoxedLaunchHandler = Box<dyn LaunchHandler>;
pub type BoxedConnectHandler = Box<dyn ConnectHandler>;
/// Used to launch a server at the specified destination, returning some result as a vec of bytes
#[async_trait]
pub trait LaunchHandler: Send + Sync {
async fn launch(
&self,
destination: &Destination,
options: &Map,
auth_client: &mut AuthClient,
) -> io::Result<Destination>;
}
#[async_trait]
impl<F, R> LaunchHandler for F
where
F: for<'a> Fn(&'a Destination, &'a Map, &'a mut AuthClient) -> R + Send + Sync + 'static,
R: Future<Output = io::Result<Destination>> + Send + 'static,
{
async fn launch(
&self,
destination: &Destination,
options: &Map,
auth_client: &mut AuthClient,
) -> io::Result<Destination> {
self(destination, options, auth_client).await
}
}
/// Used to connect to a destination, returning a connected reader and writer pair
#[async_trait]
pub trait ConnectHandler: Send + Sync {
async fn connect(
&self,
destination: &Destination,
options: &Map,
auth_client: &mut AuthClient,
) -> io::Result<BoxedDistantWriterReader>;
}
#[async_trait]
impl<F, R> ConnectHandler for F
where
F: for<'a> Fn(&'a Destination, &'a Map, &'a mut AuthClient) -> R + Send + Sync + 'static,
R: Future<Output = io::Result<BoxedDistantWriterReader>> + Send + 'static,
{
async fn connect(
&self,
destination: &Destination,
options: &Map,
auth_client: &mut AuthClient,
) -> io::Result<BoxedDistantWriterReader> {
self(destination, options, auth_client).await
}
}

@ -1,73 +0,0 @@
use super::{BoxedConnectHandler, BoxedLaunchHandler, ConnectHandler, LaunchHandler};
use distant_net::{ServerRef, ServerState};
use std::{collections::HashMap, io, sync::Weak};
use tokio::sync::RwLock;
/// Reference to a distant manager's server instance
pub struct DistantManagerRef {
/// Mapping of "scheme" -> handler
pub(crate) launch_handlers: Weak<RwLock<HashMap<String, BoxedLaunchHandler>>>,
/// Mapping of "scheme" -> handler
pub(crate) connect_handlers: Weak<RwLock<HashMap<String, BoxedConnectHandler>>>,
pub(crate) inner: Box<dyn ServerRef>,
}
impl DistantManagerRef {
/// Registers a new [`LaunchHandler`] for the specified scheme (e.g. "distant" or "ssh")
pub async fn register_launch_handler(
&self,
scheme: impl Into<String>,
handler: impl LaunchHandler + 'static,
) -> io::Result<()> {
let handlers = Weak::upgrade(&self.launch_handlers).ok_or_else(|| {
io::Error::new(
io::ErrorKind::Other,
"Handler reference is no longer available",
)
})?;
handlers
.write()
.await
.insert(scheme.into(), Box::new(handler));
Ok(())
}
/// Registers a new [`ConnectHandler`] for the specified scheme (e.g. "distant" or "ssh")
pub async fn register_connect_handler(
&self,
scheme: impl Into<String>,
handler: impl ConnectHandler + 'static,
) -> io::Result<()> {
let handlers = Weak::upgrade(&self.connect_handlers).ok_or_else(|| {
io::Error::new(
io::ErrorKind::Other,
"Handler reference is no longer available",
)
})?;
handlers
.write()
.await
.insert(scheme.into(), Box::new(handler));
Ok(())
}
}
impl ServerRef for DistantManagerRef {
fn state(&self) -> &ServerState {
self.inner.state()
}
fn is_finished(&self) -> bool {
self.inner.is_finished()
}
fn abort(&self) {
self.inner.abort();
}
}

@ -1,96 +0,0 @@
use distant_core::{
net::{FramedTransport, InmemoryTransport, IntoSplit, OneshotListener, PlainCodec},
BoxedDistantReader, BoxedDistantWriter, Destination, DistantApiServer, DistantChannelExt,
DistantManager, DistantManagerClient, DistantManagerClientConfig, DistantManagerConfig, Map,
};
use std::io;
/// Creates a client transport and server listener for our tests
/// that are connected together
async fn setup() -> (
FramedTransport<InmemoryTransport, PlainCodec>,
OneshotListener<FramedTransport<InmemoryTransport, PlainCodec>>,
) {
let (t1, t2) = InmemoryTransport::pair(100);
let listener = OneshotListener::from_value(FramedTransport::new(t2, PlainCodec));
let transport = FramedTransport::new(t1, PlainCodec);
(transport, listener)
}
#[tokio::test]
async fn should_be_able_to_establish_a_single_connection_and_communicate() {
let (transport, listener) = setup().await;
let config = DistantManagerConfig::default();
let manager_ref = DistantManager::start(config, listener).expect("Failed to start manager");
// NOTE: To pass in a raw function, we HAVE to specify the types of the parameters manually,
// otherwise we get a compilation error about lifetime mismatches
manager_ref
.register_connect_handler("scheme", |_: &_, _: &_, _: &mut _| async {
use distant_core::net::ServerExt;
let (t1, t2) = FramedTransport::pair(100);
// Spawn a server on one end
let _ = DistantApiServer::local(Default::default())
.unwrap()
.start(OneshotListener::from_value(t2.into_split()))?;
// Create a reader/writer pair on the other end
let (writer, reader) = t1.into_split();
let writer: BoxedDistantWriter = Box::new(writer);
let reader: BoxedDistantReader = Box::new(reader);
Ok((writer, reader))
})
.await
.expect("Failed to register handler");
let config = DistantManagerClientConfig::with_empty_prompts();
let mut client =
DistantManagerClient::new(config, transport).expect("Failed to connect to manager");
// Test establishing a connection to some remote server
let id = client
.connect(
"scheme://host".parse::<Destination>().unwrap(),
"key=value".parse::<Map>().unwrap(),
)
.await
.expect("Failed to connect to a remote server");
// Test retrieving list of connections
let list = client
.list()
.await
.expect("Failed to get list of connections");
assert_eq!(list.len(), 1);
assert_eq!(list.get(&id).unwrap().to_string(), "scheme://host");
// Test retrieving information
let info = client
.info(id)
.await
.expect("Failed to get info about connection");
assert_eq!(info.id, id);
assert_eq!(info.destination.to_string(), "scheme://host");
assert_eq!(info.options, "key=value".parse::<Map>().unwrap());
// Create a new channel and request some data
let mut channel = client
.open_channel(id)
.await
.expect("Failed to open channel");
let _ = channel
.system_info()
.await
.expect("Failed to get system information");
// Test killing a connection
client.kill(id).await.expect("Failed to kill connection");
// Test getting an error to ensure that serialization of that data works,
// which we do by trying to access a connection that no longer exists
let err = client.info(id).await.unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::NotConnected);
}

@ -2,6 +2,7 @@ use crate::stress::fixtures::*;
use assert_fs::prelude::*;
use distant_core::DistantChannelExt;
use rstest::*;
use test_log::test;
// 64KB is maximum TCP packet size
const MAX_TCP_PACKET_BYTES: usize = 65535;
@ -10,7 +11,7 @@ const MAX_TCP_PACKET_BYTES: usize = 65535;
const LARGE_FILE_LEN: usize = MAX_TCP_PACKET_BYTES * 10;
#[rstest]
#[tokio::test]
#[test(tokio::test)]
async fn should_handle_large_files(#[future] ctx: DistantClientCtx) {
let ctx = ctx.await;
let mut channel = ctx.client.clone_channel();

@ -2,11 +2,12 @@ use crate::stress::fixtures::*;
use assert_fs::prelude::*;
use distant_core::{data::ChangeKindSet, DistantChannelExt};
use rstest::*;
use test_log::test;
const MAX_FILES: usize = 500;
#[rstest]
#[tokio::test]
#[test(tokio::test)]
#[ignore]
async fn should_handle_large_volume_of_file_watching(#[future] ctx: DistantClientCtx) {
let ctx = ctx.await;

@ -1,14 +1,13 @@
use crate::stress::utils;
use distant_core::{DistantApiServer, DistantClient, LocalDistantApi};
use distant_net::{
PortRange, SecretKey, SecretKey32, TcpClientExt, TcpServerExt, XChaCha20Poly1305Codec,
};
use distant_core::net::client::{Client, TcpConnector};
use distant_core::net::common::authentication::{DummyAuthHandler, Verifier};
use distant_core::net::common::PortRange;
use distant_core::net::server::Server;
use distant_core::{DistantApiServerHandler, DistantClient, LocalDistantApi};
use rstest::*;
use std::net::SocketAddr;
use std::time::Duration;
use tokio::sync::mpsc;
const LOG_PATH: &str = "/tmp/test.distant.server.log";
pub struct DistantClientCtx {
pub client: DistantClient,
_done_tx: mpsc::Sender<()>,
@ -18,40 +17,41 @@ impl DistantClientCtx {
pub async fn initialize() -> Self {
let ip_addr = "127.0.0.1".parse().unwrap();
let (done_tx, mut done_rx) = mpsc::channel::<()>(1);
let (started_tx, mut started_rx) = mpsc::channel::<(u16, SecretKey32)>(1);
let (started_tx, mut started_rx) = mpsc::channel::<u16>(1);
tokio::spawn(async move {
let logger = utils::init_logging(LOG_PATH);
let key = SecretKey::default();
let codec = XChaCha20Poly1305Codec::from(key.clone());
if let Ok(api) = LocalDistantApi::initialize(Default::default()) {
if let Ok(api) = LocalDistantApi::initialize() {
let port: PortRange = "0".parse().unwrap();
let port = {
let server_ref = DistantApiServer::new(api)
.start(ip_addr, port, codec)
let handler = DistantApiServerHandler::new(api);
let server_ref = Server::new()
.handler(handler)
.verifier(Verifier::none())
.into_tcp_builder()
.start(ip_addr, port)
.await
.unwrap();
server_ref.port()
};
started_tx.send((port, key)).await.unwrap();
started_tx.send(port).await.unwrap();
let _ = done_rx.recv().await;
}
logger.flush();
logger.shutdown();
});
// Extract our server startup data if we succeeded
let (port, key) = started_rx.recv().await.unwrap();
let port = started_rx.recv().await.unwrap();
// Now initialize our client
let client = DistantClient::connect_timeout(
format!("{}:{}", ip_addr, port).parse().unwrap(),
XChaCha20Poly1305Codec::from(key),
Duration::from_secs(1),
)
let client: DistantClient = Client::build()
.auth_handler(DummyAuthHandler)
.timeout(Duration::from_secs(1))
.connector(TcpConnector::new(
format!("{}:{}", ip_addr, port)
.parse::<SocketAddr>()
.unwrap(),
))
.connect()
.await
.unwrap();

@ -1,3 +1,2 @@
mod distant;
mod fixtures;
mod utils;

@ -1,23 +0,0 @@
use std::path::PathBuf;
/// Initializes logging (should only call once)
pub fn init_logging(path: impl Into<PathBuf>) -> flexi_logger::LoggerHandle {
use flexi_logger::{FileSpec, LevelFilter, LogSpecification, Logger};
let modules = &["distant", "distant_core", "distant_ssh2"];
// Disable logging for everything but our binary, which is based on verbosity
let mut builder = LogSpecification::builder();
builder.default(LevelFilter::Off);
// For each module, configure logging
for module in modules {
builder.module(module, LevelFilter::Trace);
}
// Create our logger, but don't initialize yet
let logger = Logger::with(builder.build())
.format_for_files(flexi_logger::opt_format)
.log_to_file(FileSpec::try_from(path).expect("Failed to create log file spec"));
logger.start().expect("Failed to initialize logger")
}

@ -3,7 +3,7 @@ name = "distant-net"
description = "Network library for distant, providing implementations to support client/server architecture"
categories = ["network-programming"]
keywords = ["api", "async"]
version = "0.19.0"
version = "0.20.0"
authors = ["Chip Senkbeil <chip@senkbeil.org>"]
edition = "2021"
homepage = "https://github.com/chipsenkbeil/distant"
@ -16,7 +16,8 @@ async-trait = "0.1.57"
bytes = "1.2.1"
chacha20poly1305 = "0.10.0"
derive_more = { version = "0.99.17", default-features = false, features = ["as_mut", "as_ref", "deref", "deref_mut", "display", "from", "error", "into", "into_iterator", "is_variant", "try_into"] }
futures = "0.3.21"
dyn-clone = "1.0.9"
flate2 = "1.0.24"
hex = "0.4.3"
hkdf = "0.12.3"
log = "0.4.17"
@ -27,11 +28,13 @@ rmp-serde = "1.1.0"
sha2 = "0.10.2"
serde = { version = "1.0.142", features = ["derive"] }
serde_bytes = "0.11.7"
strum = { version = "0.24.1", features = ["derive"] }
tokio = { version = "1.20.1", features = ["full"] }
tokio-util = { version = "0.7.3", features = ["codec"] }
# Optional dependencies based on features
schemars = { version = "0.8.10", optional = true }
[dev-dependencies]
env_logger = "0.9.1"
tempfile = "3.3.0"
test-log = "0.2.11"

@ -1,122 +0,0 @@
use derive_more::Display;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
mod client;
pub use client::*;
mod handshake;
pub use handshake::*;
mod server;
pub use server::*;
/// Represents authentication messages that can be sent over the wire
///
/// NOTE: Must use serde's content attribute with the tag attribute. Just the tag attribute will
/// cause deserialization to fail
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type", content = "data")]
pub enum Auth {
/// Represents a request to perform an authentication handshake,
/// providing the public key and salt from one side in order to
/// derive the shared key
#[serde(rename = "auth_handshake")]
Handshake {
/// Bytes of the public key
#[serde(with = "serde_bytes")]
public_key: PublicKeyBytes,
/// Randomly generated salt
#[serde(with = "serde_bytes")]
salt: Salt,
},
/// Represents the bytes of an encrypted message
///
/// Underneath, will be one of either [`AuthRequest`] or [`AuthResponse`]
#[serde(rename = "auth_msg")]
Msg {
#[serde(with = "serde_bytes")]
encrypted_payload: Vec<u8>,
},
}
/// Represents authentication messages that act as initiators such as providing
/// a challenge, verifying information, presenting information, or highlighting an error
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum AuthRequest {
/// Represents a challenge comprising a series of questions to be presented
Challenge {
questions: Vec<AuthQuestion>,
options: HashMap<String, String>,
},
/// Represents an ask to verify some information
Verify { kind: AuthVerifyKind, text: String },
/// Represents some information to be presented
Info { text: String },
/// Represents some error that occurred
Error { kind: AuthErrorKind, text: String },
}
/// Represents authentication messages that are responses to auth requests such
/// as answers to challenges or verifying information
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum AuthResponse {
/// Represents the answers to a previously-asked challenge
Challenge { answers: Vec<String> },
/// Represents the answer to a previously-asked verify
Verify { valid: bool },
}
/// Represents the type of verification being requested
#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum AuthVerifyKind {
/// An ask to verify the host such as with SSH
#[display(fmt = "host")]
Host,
}
/// Represents a single question in a challenge
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct AuthQuestion {
/// The text of the question
pub text: String,
/// Any options information specific to a particular auth domain
/// such as including a username and instructions for SSH authentication
pub options: HashMap<String, String>,
}
impl AuthQuestion {
/// Creates a new question without any options data
pub fn new(text: impl Into<String>) -> Self {
Self {
text: text.into(),
options: HashMap::new(),
}
}
}
/// Represents the type of error encountered during authentication
#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AuthErrorKind {
/// When the answer(s) to a challenge do not pass authentication
FailedChallenge,
/// When verification during authentication fails
/// (e.g. a host is not allowed or blocked)
FailedVerification,
/// When the error is unknown
Unknown,
}

@ -1,817 +0,0 @@
use crate::{
utils, Auth, AuthErrorKind, AuthQuestion, AuthRequest, AuthResponse, AuthVerifyKind, Client,
Codec, Handshake, XChaCha20Poly1305Codec,
};
use bytes::BytesMut;
use log::*;
use std::{collections::HashMap, io};
pub struct AuthClient {
inner: Client<Auth, Auth>,
codec: Option<XChaCha20Poly1305Codec>,
jit_handshake: bool,
}
impl From<Client<Auth, Auth>> for AuthClient {
fn from(client: Client<Auth, Auth>) -> Self {
Self {
inner: client,
codec: None,
jit_handshake: false,
}
}
}
impl AuthClient {
/// Sends a request to the server to establish an encrypted connection
pub async fn handshake(&mut self) -> io::Result<()> {
let handshake = Handshake::default();
let response = self
.inner
.send(Auth::Handshake {
public_key: handshake.pk_bytes(),
salt: *handshake.salt(),
})
.await?;
match response.payload {
Auth::Handshake { public_key, salt } => {
let key = handshake.handshake(public_key, salt)?;
self.codec.replace(XChaCha20Poly1305Codec::new(&key));
Ok(())
}
Auth::Msg { .. } => Err(io::Error::new(
io::ErrorKind::Other,
"Got unexpected encrypted message during handshake",
)),
}
}
/// Perform a handshake only if jit is enabled and no handshake has succeeded yet
async fn jit_handshake(&mut self) -> io::Result<()> {
if self.will_jit_handshake() && !self.is_ready() {
self.handshake().await
} else {
Ok(())
}
}
/// Returns true if client has successfully performed a handshake
/// and is ready to communicate with the server
pub fn is_ready(&self) -> bool {
self.codec.is_some()
}
/// Returns true if this client will perform a handshake just-in-time (JIT) prior to making a
/// request in the scenario where the client has not already performed a handshake
#[inline]
pub fn will_jit_handshake(&self) -> bool {
self.jit_handshake
}
/// Sets the jit flag on this client with `true` indicating that this client will perform a
/// handshake just-in-time (JIT) prior to making a request in the scenario where the client has
/// not already performed a handshake
#[inline]
pub fn set_jit_handshake(&mut self, flag: bool) {
self.jit_handshake = flag;
}
/// Provides a challenge to the server and returns the answers to the questions
/// asked by the client
pub async fn challenge(
&mut self,
questions: Vec<AuthQuestion>,
options: HashMap<String, String>,
) -> io::Result<Vec<String>> {
trace!(
"AuthClient::challenge(questions = {:?}, options = {:?})",
questions,
options
);
// Perform JIT handshake if enabled
self.jit_handshake().await?;
let payload = AuthRequest::Challenge { questions, options };
let encrypted_payload = self.serialize_and_encrypt(&payload)?;
let response = self.inner.send(Auth::Msg { encrypted_payload }).await?;
match response.payload {
Auth::Msg { encrypted_payload } => {
match self.decrypt_and_deserialize(&encrypted_payload)? {
AuthResponse::Challenge { answers } => Ok(answers),
AuthResponse::Verify { .. } => Err(io::Error::new(
io::ErrorKind::Other,
"Got unexpected verify response during challenge",
)),
}
}
Auth::Handshake { .. } => Err(io::Error::new(
io::ErrorKind::Other,
"Got unexpected handshake during challenge",
)),
}
}
/// Provides a verification request to the server and returns whether or not
/// the server approved
pub async fn verify(&mut self, kind: AuthVerifyKind, text: String) -> io::Result<bool> {
trace!("AuthClient::verify(kind = {:?}, text = {:?})", kind, text);
// Perform JIT handshake if enabled
self.jit_handshake().await?;
let payload = AuthRequest::Verify { kind, text };
let encrypted_payload = self.serialize_and_encrypt(&payload)?;
let response = self.inner.send(Auth::Msg { encrypted_payload }).await?;
match response.payload {
Auth::Msg { encrypted_payload } => {
match self.decrypt_and_deserialize(&encrypted_payload)? {
AuthResponse::Verify { valid } => Ok(valid),
AuthResponse::Challenge { .. } => Err(io::Error::new(
io::ErrorKind::Other,
"Got unexpected challenge response during verify",
)),
}
}
Auth::Handshake { .. } => Err(io::Error::new(
io::ErrorKind::Other,
"Got unexpected handshake during verify",
)),
}
}
/// Provides information to the server to use as it pleases with no response expected
pub async fn info(&mut self, text: String) -> io::Result<()> {
trace!("AuthClient::info(text = {:?})", text);
// Perform JIT handshake if enabled
self.jit_handshake().await?;
let payload = AuthRequest::Info { text };
let encrypted_payload = self.serialize_and_encrypt(&payload)?;
self.inner.fire(Auth::Msg { encrypted_payload }).await
}
/// Provides an error to the server to use as it pleases with no response expected
pub async fn error(&mut self, kind: AuthErrorKind, text: String) -> io::Result<()> {
trace!("AuthClient::error(kind = {:?}, text = {:?})", kind, text);
// Perform JIT handshake if enabled
self.jit_handshake().await?;
let payload = AuthRequest::Error { kind, text };
let encrypted_payload = self.serialize_and_encrypt(&payload)?;
self.inner.fire(Auth::Msg { encrypted_payload }).await
}
fn serialize_and_encrypt(&mut self, payload: &AuthRequest) -> io::Result<Vec<u8>> {
let codec = self.codec.as_mut().ok_or_else(|| {
io::Error::new(
io::ErrorKind::Other,
"Handshake must be performed first (client encrypt message)",
)
})?;
let mut encryped_payload = BytesMut::new();
let payload = utils::serialize_to_vec(payload)?;
codec.encode(&payload, &mut encryped_payload)?;
Ok(encryped_payload.freeze().to_vec())
}
fn decrypt_and_deserialize(&mut self, payload: &[u8]) -> io::Result<AuthResponse> {
let codec = self.codec.as_mut().ok_or_else(|| {
io::Error::new(
io::ErrorKind::Other,
"Handshake must be performed first (client decrypt message)",
)
})?;
let mut payload = BytesMut::from(payload);
match codec.decode(&mut payload)? {
Some(payload) => utils::deserialize_from_slice::<AuthResponse>(&payload),
None => Err(io::Error::new(
io::ErrorKind::InvalidData,
"Incomplete message received",
)),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{Client, FramedTransport, Request, Response, TypedAsyncRead, TypedAsyncWrite};
use serde::{de::DeserializeOwned, Serialize};
const TIMEOUT_MILLIS: u64 = 100;
#[tokio::test]
async fn handshake_should_fail_if_get_unexpected_response_from_server() {
let (t, mut server) = FramedTransport::make_test_pair();
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
// We start a separate task for the client to avoid blocking since
// we also need to receive the client's request and respond
let task = tokio::spawn(async move { client.handshake().await });
// Get the request, but send a bad response
let request: Request<Auth> = server.read().await.unwrap().unwrap();
match request.payload {
Auth::Handshake { .. } => server
.write(Response::new(
request.id,
Auth::Msg {
encrypted_payload: Vec::new(),
},
))
.await
.unwrap(),
_ => panic!("Server received unexpected payload"),
}
let result = task.await.unwrap();
assert!(result.is_err(), "Handshake succeeded unexpectedly")
}
#[tokio::test]
async fn challenge_should_fail_if_handshake_not_finished() {
let (t, mut server) = FramedTransport::make_test_pair();
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
// We start a separate task for the client to avoid blocking since
// we also need to receive the client's request and respond
let task = tokio::spawn(async move { client.challenge(Vec::new(), HashMap::new()).await });
// Wait for a request, failing if we get one as the failure
// should have prevented sending anything, but we should
tokio::select! {
x = TypedAsyncRead::<Request<Auth>>::read(&mut server) => {
match x {
Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x),
Ok(None) => {},
Err(x) => panic!("Unexpectedly failed on server side: {}", x),
}
},
_ = wait_ms(TIMEOUT_MILLIS) => {
panic!("Should have gotten server closure as part of client exit");
}
}
// Verify that we got an error with the method
let result = task.await.unwrap();
assert!(result.is_err(), "Challenge succeeded unexpectedly")
}
#[tokio::test]
async fn challenge_should_fail_if_receive_wrong_response() {
let (t, mut server) = FramedTransport::make_test_pair();
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
// We start a separate task for the client to avoid blocking since
// we also need to receive the client's request and respond
let task = tokio::spawn(async move {
client.handshake().await.unwrap();
client
.challenge(
vec![
AuthQuestion::new("question1".to_string()),
AuthQuestion {
text: "question2".to_string(),
options: vec![("key2".to_string(), "value2".to_string())]
.into_iter()
.collect(),
},
],
vec![("key".to_string(), "value".to_string())]
.into_iter()
.collect(),
)
.await
});
// Wait for a handshake request and set up our encryption codec
let request: Request<Auth> = server.read().await.unwrap().unwrap();
let mut codec = match request.payload {
Auth::Handshake { public_key, salt } => {
let handshake = Handshake::default();
let key = handshake.handshake(public_key, salt).unwrap();
server
.write(Response::new(
request.id,
Auth::Handshake {
public_key: handshake.pk_bytes(),
salt: *handshake.salt(),
},
))
.await
.unwrap();
XChaCha20Poly1305Codec::new(&key)
}
_ => panic!("Server received unexpected payload"),
};
// Wait for a challenge request and send back wrong response
let request: Request<Auth> = server.read().await.unwrap().unwrap();
match request.payload {
Auth::Msg { encrypted_payload } => {
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
AuthRequest::Challenge { .. } => {
server
.write(Response::new(
request.id,
Auth::Msg {
encrypted_payload: serialize_and_encrypt(
&mut codec,
&AuthResponse::Verify { valid: true },
)
.unwrap(),
},
))
.await
.unwrap();
}
_ => panic!("Server received wrong request type"),
}
}
_ => panic!("Server received unexpected payload"),
};
// Verify that we got an error with the method
let result = task.await.unwrap();
assert!(result.is_err(), "Challenge succeeded unexpectedly")
}
#[tokio::test]
async fn challenge_should_return_answers_received_from_server() {
let (t, mut server) = FramedTransport::make_test_pair();
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
// We start a separate task for the client to avoid blocking since
// we also need to receive the client's request and respond
let task = tokio::spawn(async move {
client.handshake().await.unwrap();
client
.challenge(
vec![
AuthQuestion::new("question1".to_string()),
AuthQuestion {
text: "question2".to_string(),
options: vec![("key2".to_string(), "value2".to_string())]
.into_iter()
.collect(),
},
],
vec![("key".to_string(), "value".to_string())]
.into_iter()
.collect(),
)
.await
});
// Wait for a handshake request and set up our encryption codec
let request: Request<Auth> = server.read().await.unwrap().unwrap();
let mut codec = match request.payload {
Auth::Handshake { public_key, salt } => {
let handshake = Handshake::default();
let key = handshake.handshake(public_key, salt).unwrap();
server
.write(Response::new(
request.id,
Auth::Handshake {
public_key: handshake.pk_bytes(),
salt: *handshake.salt(),
},
))
.await
.unwrap();
XChaCha20Poly1305Codec::new(&key)
}
_ => panic!("Server received unexpected payload"),
};
// Wait for a challenge request and send back wrong response
let request: Request<Auth> = server.read().await.unwrap().unwrap();
match request.payload {
Auth::Msg { encrypted_payload } => {
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
AuthRequest::Challenge { questions, options } => {
assert_eq!(
questions,
vec![
AuthQuestion::new("question1".to_string()),
AuthQuestion {
text: "question2".to_string(),
options: vec![("key2".to_string(), "value2".to_string())]
.into_iter()
.collect(),
},
],
);
assert_eq!(
options,
vec![("key".to_string(), "value".to_string())]
.into_iter()
.collect(),
);
server
.write(Response::new(
request.id,
Auth::Msg {
encrypted_payload: serialize_and_encrypt(
&mut codec,
&AuthResponse::Challenge {
answers: vec![
"answer1".to_string(),
"answer2".to_string(),
],
},
)
.unwrap(),
},
))
.await
.unwrap();
}
_ => panic!("Server received wrong request type"),
}
}
_ => panic!("Server received unexpected payload"),
};
// Verify that we got the right results
let answers = task.await.unwrap().unwrap();
assert_eq!(answers, vec!["answer1".to_string(), "answer2".to_string()]);
}
#[tokio::test]
async fn verify_should_fail_if_handshake_not_finished() {
let (t, mut server) = FramedTransport::make_test_pair();
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
// We start a separate task for the client to avoid blocking since
// we also need to receive the client's request and respond
let task = tokio::spawn(async move {
client
.verify(AuthVerifyKind::Host, "some text".to_string())
.await
});
// Wait for a request, failing if we get one as the failure
// should have prevented sending anything, but we should
tokio::select! {
x = TypedAsyncRead::<Request<Auth>>::read(&mut server) => {
match x {
Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x),
Ok(None) => {},
Err(x) => panic!("Unexpectedly failed on server side: {}", x),
}
},
_ = wait_ms(TIMEOUT_MILLIS) => {
panic!("Should have gotten server closure as part of client exit");
}
}
// Verify that we got an error with the method
let result = task.await.unwrap();
assert!(result.is_err(), "Verify succeeded unexpectedly")
}
#[tokio::test]
async fn verify_should_fail_if_receive_wrong_response() {
let (t, mut server) = FramedTransport::make_test_pair();
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
// We start a separate task for the client to avoid blocking since
// we also need to receive the client's request and respond
let task = tokio::spawn(async move {
client.handshake().await.unwrap();
client
.verify(AuthVerifyKind::Host, "some text".to_string())
.await
});
// Wait for a handshake request and set up our encryption codec
let request: Request<Auth> = server.read().await.unwrap().unwrap();
let mut codec = match request.payload {
Auth::Handshake { public_key, salt } => {
let handshake = Handshake::default();
let key = handshake.handshake(public_key, salt).unwrap();
server
.write(Response::new(
request.id,
Auth::Handshake {
public_key: handshake.pk_bytes(),
salt: *handshake.salt(),
},
))
.await
.unwrap();
XChaCha20Poly1305Codec::new(&key)
}
_ => panic!("Server received unexpected payload"),
};
// Wait for a verify request and send back wrong response
let request: Request<Auth> = server.read().await.unwrap().unwrap();
match request.payload {
Auth::Msg { encrypted_payload } => {
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
AuthRequest::Verify { .. } => {
server
.write(Response::new(
request.id,
Auth::Msg {
encrypted_payload: serialize_and_encrypt(
&mut codec,
&AuthResponse::Challenge {
answers: Vec::new(),
},
)
.unwrap(),
},
))
.await
.unwrap();
}
_ => panic!("Server received wrong request type"),
}
}
_ => panic!("Server received unexpected payload"),
};
// Verify that we got an error with the method
let result = task.await.unwrap();
assert!(result.is_err(), "Verify succeeded unexpectedly")
}
#[tokio::test]
async fn verify_should_return_valid_bool_received_from_server() {
let (t, mut server) = FramedTransport::make_test_pair();
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
// We start a separate task for the client to avoid blocking since
// we also need to receive the client's request and respond
let task = tokio::spawn(async move {
client.handshake().await.unwrap();
client
.verify(AuthVerifyKind::Host, "some text".to_string())
.await
});
// Wait for a handshake request and set up our encryption codec
let request: Request<Auth> = server.read().await.unwrap().unwrap();
let mut codec = match request.payload {
Auth::Handshake { public_key, salt } => {
let handshake = Handshake::default();
let key = handshake.handshake(public_key, salt).unwrap();
server
.write(Response::new(
request.id,
Auth::Handshake {
public_key: handshake.pk_bytes(),
salt: *handshake.salt(),
},
))
.await
.unwrap();
XChaCha20Poly1305Codec::new(&key)
}
_ => panic!("Server received unexpected payload"),
};
// Wait for a challenge request and send back wrong response
let request: Request<Auth> = server.read().await.unwrap().unwrap();
match request.payload {
Auth::Msg { encrypted_payload } => {
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
AuthRequest::Verify { kind, text } => {
assert_eq!(kind, AuthVerifyKind::Host);
assert_eq!(text, "some text");
server
.write(Response::new(
request.id,
Auth::Msg {
encrypted_payload: serialize_and_encrypt(
&mut codec,
&AuthResponse::Verify { valid: true },
)
.unwrap(),
},
))
.await
.unwrap();
}
_ => panic!("Server received wrong request type"),
}
}
_ => panic!("Server received unexpected payload"),
};
// Verify that we got the right results
let valid = task.await.unwrap().unwrap();
assert!(valid, "Got verify response, but valid was set incorrectly");
}
#[tokio::test]
async fn info_should_fail_if_handshake_not_finished() {
let (t, mut server) = FramedTransport::make_test_pair();
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
// We start a separate task for the client to avoid blocking since
// we also need to receive the client's request and respond
let task = tokio::spawn(async move { client.info("some text".to_string()).await });
// Wait for a request, failing if we get one as the failure
// should have prevented sending anything, but we should
tokio::select! {
x = TypedAsyncRead::<Request<Auth>>::read(&mut server) => {
match x {
Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x),
Ok(None) => {},
Err(x) => panic!("Unexpectedly failed on server side: {}", x),
}
},
_ = wait_ms(TIMEOUT_MILLIS) => {
panic!("Should have gotten server closure as part of client exit");
}
}
// Verify that we got an error with the method
let result = task.await.unwrap();
assert!(result.is_err(), "Info succeeded unexpectedly")
}
#[tokio::test]
async fn info_should_send_the_server_a_request_but_not_wait_for_a_response() {
let (t, mut server) = FramedTransport::make_test_pair();
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
// We start a separate task for the client to avoid blocking since
// we also need to receive the client's request and respond
let task = tokio::spawn(async move {
client.handshake().await.unwrap();
client.info("some text".to_string()).await
});
// Wait for a handshake request and set up our encryption codec
let request: Request<Auth> = server.read().await.unwrap().unwrap();
let mut codec = match request.payload {
Auth::Handshake { public_key, salt } => {
let handshake = Handshake::default();
let key = handshake.handshake(public_key, salt).unwrap();
server
.write(Response::new(
request.id,
Auth::Handshake {
public_key: handshake.pk_bytes(),
salt: *handshake.salt(),
},
))
.await
.unwrap();
XChaCha20Poly1305Codec::new(&key)
}
_ => panic!("Server received unexpected payload"),
};
// Wait for a request
let request: Request<Auth> = server.read().await.unwrap().unwrap();
match request.payload {
Auth::Msg { encrypted_payload } => {
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
AuthRequest::Info { text } => {
assert_eq!(text, "some text");
}
_ => panic!("Server received wrong request type"),
}
}
_ => panic!("Server received unexpected payload"),
};
// Verify that we got the right results
task.await.unwrap().unwrap();
}
#[tokio::test]
async fn error_should_fail_if_handshake_not_finished() {
let (t, mut server) = FramedTransport::make_test_pair();
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
// We start a separate task for the client to avoid blocking since
// we also need to receive the client's request and respond
let task = tokio::spawn(async move {
client
.error(AuthErrorKind::FailedChallenge, "some text".to_string())
.await
});
// Wait for a request, failing if we get one as the failure
// should have prevented sending anything, but we should
tokio::select! {
x = TypedAsyncRead::<Request<Auth>>::read(&mut server) => {
match x {
Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x),
Ok(None) => {},
Err(x) => panic!("Unexpectedly failed on server side: {}", x),
}
},
_ = wait_ms(TIMEOUT_MILLIS) => {
panic!("Should have gotten server closure as part of client exit");
}
}
// Verify that we got an error with the method
let result = task.await.unwrap();
assert!(result.is_err(), "Error succeeded unexpectedly")
}
#[tokio::test]
async fn error_should_send_the_server_a_request_but_not_wait_for_a_response() {
let (t, mut server) = FramedTransport::make_test_pair();
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
// We start a separate task for the client to avoid blocking since
// we also need to receive the client's request and respond
let task = tokio::spawn(async move {
client.handshake().await.unwrap();
client
.error(AuthErrorKind::FailedChallenge, "some text".to_string())
.await
});
// Wait for a handshake request and set up our encryption codec
let request: Request<Auth> = server.read().await.unwrap().unwrap();
let mut codec = match request.payload {
Auth::Handshake { public_key, salt } => {
let handshake = Handshake::default();
let key = handshake.handshake(public_key, salt).unwrap();
server
.write(Response::new(
request.id,
Auth::Handshake {
public_key: handshake.pk_bytes(),
salt: *handshake.salt(),
},
))
.await
.unwrap();
XChaCha20Poly1305Codec::new(&key)
}
_ => panic!("Server received unexpected payload"),
};
// Wait for a request
let request: Request<Auth> = server.read().await.unwrap().unwrap();
match request.payload {
Auth::Msg { encrypted_payload } => {
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
AuthRequest::Error { kind, text } => {
assert_eq!(kind, AuthErrorKind::FailedChallenge);
assert_eq!(text, "some text");
}
_ => panic!("Server received wrong request type"),
}
}
_ => panic!("Server received unexpected payload"),
};
// Verify that we got the right results
task.await.unwrap().unwrap();
}
async fn wait_ms(ms: u64) {
use std::time::Duration;
tokio::time::sleep(Duration::from_millis(ms)).await;
}
fn serialize_and_encrypt<T: Serialize>(
codec: &mut XChaCha20Poly1305Codec,
payload: &T,
) -> io::Result<Vec<u8>> {
let mut encryped_payload = BytesMut::new();
let payload = utils::serialize_to_vec(payload)?;
codec.encode(&payload, &mut encryped_payload)?;
Ok(encryped_payload.freeze().to_vec())
}
fn decrypt_and_deserialize<T: DeserializeOwned>(
codec: &mut XChaCha20Poly1305Codec,
payload: &[u8],
) -> io::Result<T> {
let mut payload = BytesMut::from(payload);
match codec.decode(&mut payload)? {
Some(payload) => utils::deserialize_from_slice::<T>(&payload),
None => Err(io::Error::new(
io::ErrorKind::InvalidData,
"Incomplete message received",
)),
}
}
}

@ -1,653 +0,0 @@
use crate::{
utils, Auth, AuthErrorKind, AuthQuestion, AuthRequest, AuthResponse, AuthVerifyKind, Codec,
Handshake, Server, ServerCtx, XChaCha20Poly1305Codec,
};
use async_trait::async_trait;
use bytes::BytesMut;
use log::*;
use std::{collections::HashMap, io};
use tokio::sync::RwLock;
/// Type signature for a dynamic on_challenge function
pub type AuthChallengeFn =
dyn Fn(Vec<AuthQuestion>, HashMap<String, String>) -> Vec<String> + Send + Sync;
/// Type signature for a dynamic on_verify function
pub type AuthVerifyFn = dyn Fn(AuthVerifyKind, String) -> bool + Send + Sync;
/// Type signature for a dynamic on_info function
pub type AuthInfoFn = dyn Fn(String) + Send + Sync;
/// Type signature for a dynamic on_error function
pub type AuthErrorFn = dyn Fn(AuthErrorKind, String) + Send + Sync;
/// Represents an [`AuthServer`] where all handlers are stored on the heap
pub type HeapAuthServer =
AuthServer<Box<AuthChallengeFn>, Box<AuthVerifyFn>, Box<AuthInfoFn>, Box<AuthErrorFn>>;
/// Server that handles authentication
pub struct AuthServer<ChallengeFn, VerifyFn, InfoFn, ErrorFn>
where
ChallengeFn: Fn(Vec<AuthQuestion>, HashMap<String, String>) -> Vec<String> + Send + Sync,
VerifyFn: Fn(AuthVerifyKind, String) -> bool + Send + Sync,
InfoFn: Fn(String) + Send + Sync,
ErrorFn: Fn(AuthErrorKind, String) + Send + Sync,
{
pub on_challenge: ChallengeFn,
pub on_verify: VerifyFn,
pub on_info: InfoFn,
pub on_error: ErrorFn,
}
#[async_trait]
impl<ChallengeFn, VerifyFn, InfoFn, ErrorFn> Server
for AuthServer<ChallengeFn, VerifyFn, InfoFn, ErrorFn>
where
ChallengeFn: Fn(Vec<AuthQuestion>, HashMap<String, String>) -> Vec<String> + Send + Sync,
VerifyFn: Fn(AuthVerifyKind, String) -> bool + Send + Sync,
InfoFn: Fn(String) + Send + Sync,
ErrorFn: Fn(AuthErrorKind, String) + Send + Sync,
{
type Request = Auth;
type Response = Auth;
type LocalData = RwLock<Option<XChaCha20Poly1305Codec>>;
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) {
let reply = ctx.reply.clone();
match ctx.request.payload {
Auth::Handshake { public_key, salt } => {
trace!(
"Received handshake request from client, request id = {}",
ctx.request.id
);
let handshake = Handshake::default();
match handshake.handshake(public_key, salt) {
Ok(key) => {
ctx.local_data
.write()
.await
.replace(XChaCha20Poly1305Codec::new(&key));
trace!(
"Sending reciprocal handshake to client, response origin id = {}",
ctx.request.id
);
if let Err(x) = reply
.send(Auth::Handshake {
public_key: handshake.pk_bytes(),
salt: *handshake.salt(),
})
.await
{
error!("[Conn {}] {}", ctx.connection_id, x);
}
}
Err(x) => {
error!("[Conn {}] {}", ctx.connection_id, x);
return;
}
}
}
Auth::Msg {
ref encrypted_payload,
} => {
trace!(
"Received auth msg, encrypted payload size = {}",
encrypted_payload.len()
);
// Attempt to decrypt the message so we can understand what to do
let request = match ctx.local_data.write().await.as_mut() {
Some(codec) => {
let mut payload = BytesMut::from(encrypted_payload.as_slice());
match codec.decode(&mut payload) {
Ok(Some(payload)) => {
utils::deserialize_from_slice::<AuthRequest>(&payload)
}
Ok(None) => Err(io::Error::new(
io::ErrorKind::InvalidData,
"Incomplete message received",
)),
Err(x) => Err(x),
}
}
None => Err(io::Error::new(
io::ErrorKind::Other,
"Handshake must be performed first (server decrypt message)",
)),
};
let response = match request {
Ok(request) => match request {
AuthRequest::Challenge { questions, options } => {
trace!("Received challenge request");
trace!("questions = {:?}", questions);
trace!("options = {:?}", options);
let answers = (self.on_challenge)(questions, options);
AuthResponse::Challenge { answers }
}
AuthRequest::Verify { kind, text } => {
trace!("Received verify request");
trace!("kind = {:?}", kind);
trace!("text = {:?}", text);
let valid = (self.on_verify)(kind, text);
AuthResponse::Verify { valid }
}
AuthRequest::Info { text } => {
trace!("Received info request");
trace!("text = {:?}", text);
(self.on_info)(text);
return;
}
AuthRequest::Error { kind, text } => {
trace!("Received error request");
trace!("kind = {:?}", kind);
trace!("text = {:?}", text);
(self.on_error)(kind, text);
return;
}
},
Err(x) => {
error!("[Conn {}] {}", ctx.connection_id, x);
return;
}
};
// Serialize and encrypt the message before sending it back
let encrypted_payload = match ctx.local_data.write().await.as_mut() {
Some(codec) => {
let mut encrypted_payload = BytesMut::new();
// Convert the response into bytes for us to send back
match utils::serialize_to_vec(&response) {
Ok(bytes) => match codec.encode(&bytes, &mut encrypted_payload) {
Ok(_) => Ok(encrypted_payload.freeze().to_vec()),
Err(x) => Err(x),
},
Err(x) => Err(x),
}
}
None => Err(io::Error::new(
io::ErrorKind::Other,
"Handshake must be performed first (server encrypt messaage)",
)),
};
match encrypted_payload {
Ok(encrypted_payload) => {
if let Err(x) = reply.send(Auth::Msg { encrypted_payload }).await {
error!("[Conn {}] {}", ctx.connection_id, x);
return;
}
}
Err(x) => {
error!("[Conn {}] {}", ctx.connection_id, x);
return;
}
}
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
IntoSplit, MpscListener, MpscTransport, Request, Response, ServerExt, ServerRef,
TypedAsyncRead, TypedAsyncWrite,
};
use tokio::sync::mpsc;
const TIMEOUT_MILLIS: u64 = 100;
#[tokio::test]
async fn should_not_reply_if_receive_encrypted_msg_without_handshake_first() {
let (mut t, _) = spawn_auth_server(
/* on_challenge */ |_, _| Vec::new(),
/* on_verify */ |_, _| false,
/* on_info */ |_| {},
/* on_error */ |_, _| {},
)
.await
.expect("Failed to spawn server");
// Send an encrypted message before establishing a handshake
t.write(Request::new(Auth::Msg {
encrypted_payload: Vec::new(),
}))
.await
.expect("Failed to send request to server");
// Wait for a response, failing if we get one
tokio::select! {
x = t.read() => panic!("Unexpectedly resolved: {:?}", x),
_ = wait_ms(TIMEOUT_MILLIS) => {}
}
}
#[tokio::test]
async fn should_reply_to_handshake_request_with_new_public_key_and_salt() {
let (mut t, _) = spawn_auth_server(
/* on_challenge */ |_, _| Vec::new(),
/* on_verify */ |_, _| false,
/* on_info */ |_| {},
/* on_error */ |_, _| {},
)
.await
.expect("Failed to spawn server");
// Send a handshake
let handshake = Handshake::default();
t.write(Request::new(Auth::Handshake {
public_key: handshake.pk_bytes(),
salt: *handshake.salt(),
}))
.await
.expect("Failed to send request to server");
// Wait for a handshake response
tokio::select! {
x = t.read() => {
let response = x.expect("Request failed").expect("Response missing");
match response.payload {
Auth::Handshake { .. } => {},
Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
}
}
_ = wait_ms(TIMEOUT_MILLIS) => panic!("Ran out of time waiting on response"),
}
}
#[tokio::test]
async fn should_not_reply_if_receive_invalid_encrypted_msg() {
let (mut t, _) = spawn_auth_server(
/* on_challenge */ |_, _| Vec::new(),
/* on_verify */ |_, _| false,
/* on_info */ |_| {},
/* on_error */ |_, _| {},
)
.await
.expect("Failed to spawn server");
// Send a handshake
let handshake = Handshake::default();
t.write(Request::new(Auth::Handshake {
public_key: handshake.pk_bytes(),
salt: *handshake.salt(),
}))
.await
.expect("Failed to send request to server");
// Complete handshake
let key = match t.read().await.unwrap().unwrap().payload {
Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(),
Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
};
// Send a bad chunk of data
let _codec = XChaCha20Poly1305Codec::new(&key);
t.write(Request::new(Auth::Msg {
encrypted_payload: vec![1, 2, 3, 4],
}))
.await
.unwrap();
// Wait for a response, failing if we get one
tokio::select! {
x = t.read() => panic!("Unexpectedly resolved: {:?}", x),
_ = wait_ms(TIMEOUT_MILLIS) => {}
}
}
#[tokio::test]
async fn should_invoke_appropriate_function_when_receive_challenge_request_and_reply() {
let (tx, mut rx) = mpsc::channel(1);
let (mut t, _) = spawn_auth_server(
/* on_challenge */
move |questions, options| {
tx.try_send((questions, options)).unwrap();
vec!["answer1".to_string(), "answer2".to_string()]
},
/* on_verify */ |_, _| false,
/* on_info */ |_| {},
/* on_error */ |_, _| {},
)
.await
.expect("Failed to spawn server");
// Send a handshake
let handshake = Handshake::default();
t.write(Request::new(Auth::Handshake {
public_key: handshake.pk_bytes(),
salt: *handshake.salt(),
}))
.await
.expect("Failed to send request to server");
// Complete handshake
let key = match t.read().await.unwrap().unwrap().payload {
Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(),
Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
};
// Send an error request
let mut codec = XChaCha20Poly1305Codec::new(&key);
t.write(Request::new(Auth::Msg {
encrypted_payload: serialize_and_encrypt(
&mut codec,
&AuthRequest::Challenge {
questions: vec![
AuthQuestion::new("question1".to_string()),
AuthQuestion {
text: "question2".to_string(),
options: vec![("key".to_string(), "value".to_string())]
.into_iter()
.collect(),
},
],
options: vec![("hello".to_string(), "world".to_string())]
.into_iter()
.collect(),
},
)
.unwrap(),
}))
.await
.unwrap();
// Verify that the handler was triggered
let (questions, options) = rx.recv().await.expect("Channel closed unexpectedly");
assert_eq!(
questions,
vec![
AuthQuestion::new("question1".to_string()),
AuthQuestion {
text: "question2".to_string(),
options: vec![("key".to_string(), "value".to_string())]
.into_iter()
.collect(),
}
]
);
assert_eq!(
options,
vec![("hello".to_string(), "world".to_string())]
.into_iter()
.collect()
);
// Wait for a response and verify that it matches what we expect
tokio::select! {
x = t.read() => {
let response = x.expect("Request failed").expect("Response missing");
match response.payload {
Auth::Handshake { .. } => panic!("Received unexpected handshake"),
Auth::Msg { encrypted_payload } => {
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
AuthResponse::Challenge { answers } =>
assert_eq!(
answers,
vec!["answer1".to_string(), "answer2".to_string()]
),
_ => panic!("Got wrong response for verify"),
}
},
}
}
_ = wait_ms(TIMEOUT_MILLIS) => {}
}
}
#[tokio::test]
async fn should_invoke_appropriate_function_when_receive_verify_request_and_reply() {
let (tx, mut rx) = mpsc::channel(1);
let (mut t, _) = spawn_auth_server(
/* on_challenge */ |_, _| Vec::new(),
/* on_verify */
move |kind, text| {
tx.try_send((kind, text)).unwrap();
true
},
/* on_info */ |_| {},
/* on_error */ |_, _| {},
)
.await
.expect("Failed to spawn server");
// Send a handshake
let handshake = Handshake::default();
t.write(Request::new(Auth::Handshake {
public_key: handshake.pk_bytes(),
salt: *handshake.salt(),
}))
.await
.expect("Failed to send request to server");
// Complete handshake
let key = match t.read().await.unwrap().unwrap().payload {
Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(),
Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
};
// Send an error request
let mut codec = XChaCha20Poly1305Codec::new(&key);
t.write(Request::new(Auth::Msg {
encrypted_payload: serialize_and_encrypt(
&mut codec,
&AuthRequest::Verify {
kind: AuthVerifyKind::Host,
text: "some text".to_string(),
},
)
.unwrap(),
}))
.await
.unwrap();
// Verify that the handler was triggered
let (kind, text) = rx.recv().await.expect("Channel closed unexpectedly");
assert_eq!(kind, AuthVerifyKind::Host);
assert_eq!(text, "some text");
// Wait for a response and verify that it matches what we expect
tokio::select! {
x = t.read() => {
let response = x.expect("Request failed").expect("Response missing");
match response.payload {
Auth::Handshake { .. } => panic!("Received unexpected handshake"),
Auth::Msg { encrypted_payload } => {
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
AuthResponse::Verify { valid } =>
assert!(valid, "Got verify, but valid was wrong"),
_ => panic!("Got wrong response for verify"),
}
},
}
}
_ = wait_ms(TIMEOUT_MILLIS) => {}
}
}
#[tokio::test]
async fn should_invoke_appropriate_function_when_receive_info_request() {
let (tx, mut rx) = mpsc::channel(1);
let (mut t, _) = spawn_auth_server(
/* on_challenge */ |_, _| Vec::new(),
/* on_verify */ |_, _| false,
/* on_info */
move |text| {
tx.try_send(text).unwrap();
},
/* on_error */ |_, _| {},
)
.await
.expect("Failed to spawn server");
// Send a handshake
let handshake = Handshake::default();
t.write(Request::new(Auth::Handshake {
public_key: handshake.pk_bytes(),
salt: *handshake.salt(),
}))
.await
.expect("Failed to send request to server");
// Complete handshake
let key = match t.read().await.unwrap().unwrap().payload {
Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(),
Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
};
// Send an error request
let mut codec = XChaCha20Poly1305Codec::new(&key);
t.write(Request::new(Auth::Msg {
encrypted_payload: serialize_and_encrypt(
&mut codec,
&AuthRequest::Info {
text: "some text".to_string(),
},
)
.unwrap(),
}))
.await
.unwrap();
// Verify that the handler was triggered
let text = rx.recv().await.expect("Channel closed unexpectedly");
assert_eq!(text, "some text");
// Wait for a response, failing if we get one
tokio::select! {
x = t.read() => panic!("Unexpectedly resolved: {:?}", x),
_ = wait_ms(TIMEOUT_MILLIS) => {}
}
}
#[tokio::test]
async fn should_invoke_appropriate_function_when_receive_error_request() {
let (tx, mut rx) = mpsc::channel(1);
let (mut t, _) = spawn_auth_server(
/* on_challenge */ |_, _| Vec::new(),
/* on_verify */ |_, _| false,
/* on_info */ |_| {},
/* on_error */
move |kind, text| {
tx.try_send((kind, text)).unwrap();
},
)
.await
.expect("Failed to spawn server");
// Send a handshake
let handshake = Handshake::default();
t.write(Request::new(Auth::Handshake {
public_key: handshake.pk_bytes(),
salt: *handshake.salt(),
}))
.await
.expect("Failed to send request to server");
// Complete handshake
let key = match t.read().await.unwrap().unwrap().payload {
Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(),
Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
};
// Send an error request
let mut codec = XChaCha20Poly1305Codec::new(&key);
t.write(Request::new(Auth::Msg {
encrypted_payload: serialize_and_encrypt(
&mut codec,
&AuthRequest::Error {
kind: AuthErrorKind::FailedChallenge,
text: "some text".to_string(),
},
)
.unwrap(),
}))
.await
.unwrap();
// Verify that the handler was triggered
let (kind, text) = rx.recv().await.expect("Channel closed unexpectedly");
assert_eq!(kind, AuthErrorKind::FailedChallenge);
assert_eq!(text, "some text");
// Wait for a response, failing if we get one
tokio::select! {
x = t.read() => panic!("Unexpectedly resolved: {:?}", x),
_ = wait_ms(TIMEOUT_MILLIS) => {}
}
}
async fn wait_ms(ms: u64) {
use std::time::Duration;
tokio::time::sleep(Duration::from_millis(ms)).await;
}
fn serialize_and_encrypt(
codec: &mut XChaCha20Poly1305Codec,
payload: &AuthRequest,
) -> io::Result<Vec<u8>> {
let mut encryped_payload = BytesMut::new();
let payload = utils::serialize_to_vec(payload)?;
codec.encode(&payload, &mut encryped_payload)?;
Ok(encryped_payload.freeze().to_vec())
}
fn decrypt_and_deserialize(
codec: &mut XChaCha20Poly1305Codec,
payload: &[u8],
) -> io::Result<AuthResponse> {
let mut payload = BytesMut::from(payload);
match codec.decode(&mut payload)? {
Some(payload) => utils::deserialize_from_slice::<AuthResponse>(&payload),
None => Err(io::Error::new(
io::ErrorKind::InvalidData,
"Incomplete message received",
)),
}
}
async fn spawn_auth_server<ChallengeFn, VerifyFn, InfoFn, ErrorFn>(
on_challenge: ChallengeFn,
on_verify: VerifyFn,
on_info: InfoFn,
on_error: ErrorFn,
) -> io::Result<(
MpscTransport<Request<Auth>, Response<Auth>>,
Box<dyn ServerRef>,
)>
where
ChallengeFn:
Fn(Vec<AuthQuestion>, HashMap<String, String>) -> Vec<String> + Send + Sync + 'static,
VerifyFn: Fn(AuthVerifyKind, String) -> bool + Send + Sync + 'static,
InfoFn: Fn(String) + Send + Sync + 'static,
ErrorFn: Fn(AuthErrorKind, String) + Send + Sync + 'static,
{
let server = AuthServer {
on_challenge,
on_verify,
on_info,
on_error,
};
// Create a test listener where we will forward a connection
let (tx, listener) = MpscListener::channel(100);
// Make bounded transport pair and send off one of them to act as our connection
let (transport, connection) = MpscTransport::<Request<Auth>, Response<Auth>>::pair(100);
tx.send(connection.into_split())
.await
.expect("Failed to feed listener a connection");
let server = server.start(listener)?;
Ok((transport, server))
}
}

File diff suppressed because it is too large Load Diff

@ -0,0 +1,142 @@
mod tcp;
pub use tcp::*;
#[cfg(unix)]
mod unix;
#[cfg(unix)]
pub use unix::*;
#[cfg(windows)]
mod windows;
#[cfg(windows)]
pub use windows::*;
use crate::client::{Client, ReconnectStrategy, UntypedClient};
use crate::common::{authentication::AuthHandler, Connection, Transport};
use async_trait::async_trait;
use std::{convert, io, time::Duration};
/// Interface that performs the connection to produce a [`Transport`] for use by the [`Client`].
#[async_trait]
pub trait Connector {
/// Type of transport produced by the connection.
type Transport: Transport + 'static;
async fn connect(self) -> io::Result<Self::Transport>;
}
#[async_trait]
impl<T: Transport + 'static> Connector for T {
type Transport = T;
async fn connect(self) -> io::Result<Self::Transport> {
Ok(self)
}
}
/// Builder for a [`Client`] or [`UntypedClient`].
pub struct ClientBuilder<H, C> {
auth_handler: H,
connector: C,
reconnect_strategy: ReconnectStrategy,
timeout: Option<Duration>,
}
impl<H, C> ClientBuilder<H, C> {
pub fn auth_handler<U>(self, auth_handler: U) -> ClientBuilder<U, C> {
ClientBuilder {
auth_handler,
connector: self.connector,
reconnect_strategy: self.reconnect_strategy,
timeout: self.timeout,
}
}
pub fn connector<U>(self, connector: U) -> ClientBuilder<H, U> {
ClientBuilder {
auth_handler: self.auth_handler,
connector,
reconnect_strategy: self.reconnect_strategy,
timeout: self.timeout,
}
}
pub fn reconnect_strategy(self, reconnect_strategy: ReconnectStrategy) -> ClientBuilder<H, C> {
ClientBuilder {
auth_handler: self.auth_handler,
connector: self.connector,
reconnect_strategy,
timeout: self.timeout,
}
}
pub fn timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
Self {
auth_handler: self.auth_handler,
connector: self.connector,
reconnect_strategy: self.reconnect_strategy,
timeout: timeout.into(),
}
}
}
impl ClientBuilder<(), ()> {
pub fn new() -> Self {
Self {
auth_handler: (),
reconnect_strategy: ReconnectStrategy::default(),
connector: (),
timeout: None,
}
}
}
impl Default for ClientBuilder<(), ()> {
fn default() -> Self {
Self::new()
}
}
impl<H, C> ClientBuilder<H, C>
where
H: AuthHandler + Send,
C: Connector,
{
/// Establishes a connection with a remote server using the configured [`Transport`]
/// and other settings, returning a new [`UntypedClient`] instance once the connection
/// is fully established and authenticated.
pub async fn connect_untyped(self) -> io::Result<UntypedClient> {
let auth_handler = self.auth_handler;
let retry_strategy = self.reconnect_strategy;
let timeout = self.timeout;
let f = async move {
let transport = match timeout {
Some(duration) => tokio::time::timeout(duration, self.connector.connect())
.await
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
.and_then(convert::identity)?,
None => self.connector.connect().await?,
};
let connection = Connection::client(transport, auth_handler).await?;
Ok(UntypedClient::spawn(connection, retry_strategy))
};
match timeout {
Some(duration) => tokio::time::timeout(duration, f)
.await
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
.and_then(convert::identity),
None => f.await,
}
}
/// Establishes a connection with a remote server using the configured [`Transport`] and other
/// settings, returning a new [`Client`] instance once the connection is fully established and
/// authenticated.
pub async fn connect<T, U>(self) -> io::Result<Client<T, U>> {
Ok(self.connect_untyped().await?.into_typed_client())
}
}

@ -0,0 +1,31 @@
use super::Connector;
use crate::common::TcpTransport;
use async_trait::async_trait;
use std::io;
use tokio::net::ToSocketAddrs;
/// Implementation of [`Connector`] to support connecting via TCP.
pub struct TcpConnector<T> {
addr: T,
}
impl<T> TcpConnector<T> {
pub fn new(addr: T) -> Self {
Self { addr }
}
}
impl<T> From<T> for TcpConnector<T> {
fn from(addr: T) -> Self {
Self::new(addr)
}
}
#[async_trait]
impl<T: ToSocketAddrs + Send> Connector for TcpConnector<T> {
type Transport = TcpTransport;
async fn connect(self) -> io::Result<Self::Transport> {
TcpTransport::connect(self.addr).await
}
}

@ -0,0 +1,30 @@
use super::Connector;
use crate::common::UnixSocketTransport;
use async_trait::async_trait;
use std::{io, path::PathBuf};
/// Implementation of [`Connector`] to support connecting via a Unix socket.
pub struct UnixSocketConnector {
path: PathBuf,
}
impl UnixSocketConnector {
pub fn new(path: impl Into<PathBuf>) -> Self {
Self { path: path.into() }
}
}
impl<T: Into<PathBuf>> From<T> for UnixSocketConnector {
fn from(path: T) -> Self {
Self::new(path)
}
}
#[async_trait]
impl Connector for UnixSocketConnector {
type Transport = UnixSocketTransport;
async fn connect(self) -> io::Result<Self::Transport> {
UnixSocketTransport::connect(self.path).await
}
}

@ -0,0 +1,50 @@
use super::Connector;
use crate::common::WindowsPipeTransport;
use async_trait::async_trait;
use std::ffi::OsString;
use std::io;
/// Implementation of [`Connector`] to support connecting via a Windows named pipe.
pub struct WindowsPipeConnector {
addr: OsString,
pub(crate) local: bool,
}
impl WindowsPipeConnector {
/// Creates a new connector for a non-local pipe using the given `addr`.
pub fn new(addr: impl Into<OsString>) -> Self {
Self {
addr: addr.into(),
local: false,
}
}
/// Creates a new connector for a local pipe using the given `name`.
pub fn local(name: impl Into<OsString>) -> Self {
Self {
addr: name.into(),
local: true,
}
}
}
impl<T: Into<OsString>> From<T> for WindowsPipeConnector {
fn from(addr: T) -> Self {
Self::new(addr)
}
}
#[async_trait]
impl Connector for WindowsPipeConnector {
type Transport = WindowsPipeTransport;
async fn connect(self) -> io::Result<Self::Transport> {
if self.local {
let mut full_addr = OsString::from(r"\\.\pipe\");
full_addr.push(self.addr);
WindowsPipeTransport::connect(full_addr).await
} else {
WindowsPipeTransport::connect(self.addr).await
}
}
}

@ -1,5 +1,7 @@
use crate::{Request, Response};
use std::{convert, io, sync::Weak};
use crate::common::{Request, Response, UntypedRequest, UntypedResponse};
use log::*;
use serde::{de::DeserializeOwned, Serialize};
use std::{convert, fmt, io, marker::PhantomData, sync::Weak};
use tokio::{sync::mpsc, time::Duration};
mod mailbox;
@ -9,26 +11,181 @@ pub use mailbox::*;
const CHANNEL_MAILBOX_CAPACITY: usize = 10000;
/// Represents a sender of requests tied to a session, holding onto a weak reference of
/// mailboxes to relay responses, meaning that once the [`Session`] is closed or dropped,
/// any sent request will no longer be able to receive responses
pub struct Channel<T, U>
/// mailboxes to relay responses, meaning that once the [`Client`] is closed or dropped,
/// any sent request will no longer be able to receive responses.
///
/// [`Client`]: crate::client::Client
pub struct Channel<T, U> {
inner: UntypedChannel,
_request: PhantomData<T>,
_response: PhantomData<U>,
}
// NOTE: Implemented manually to avoid needing clone to be defined on generic types
impl<T, U> Clone for Channel<T, U> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
_request: self._request,
_response: self._response,
}
}
}
impl<T, U> fmt::Debug for Channel<T, U> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Channel")
.field("tx", &self.inner.tx)
.field("post_office", &self.inner.post_office)
.field("_request", &self._request)
.field("_response", &self._response)
.finish()
}
}
impl<T, U> Channel<T, U>
where
T: Send + Sync,
U: Send + Sync,
T: Send + Sync + Serialize + 'static,
U: Send + Sync + DeserializeOwned + 'static,
{
/// Returns true if no more requests can be transferred
pub fn is_closed(&self) -> bool {
self.inner.is_closed()
}
/// Consumes this channel, returning an untyped variant
pub fn into_untyped_channel(self) -> UntypedChannel {
self.inner
}
/// Assigns a default mailbox for any response received that does not match another mailbox.
pub async fn assign_default_mailbox(&self, buffer: usize) -> io::Result<Mailbox<Response<U>>> {
Ok(map_to_typed_mailbox(
self.inner.assign_default_mailbox(buffer).await?,
))
}
/// Removes the default mailbox used for unmatched responses such that any response without a
/// matching mailbox will be dropped.
pub async fn remove_default_mailbox(&self) -> io::Result<()> {
self.inner.remove_default_mailbox().await
}
/// Sends a request and returns a mailbox that can receive one or more responses, failing if
/// unable to send a request or if the session's receiving line to the remote server has
/// already been severed
pub async fn mail(&mut self, req: impl Into<Request<T>>) -> io::Result<Mailbox<Response<U>>> {
Ok(map_to_typed_mailbox(
self.inner.mail(req.into().to_untyped_request()?).await?,
))
}
/// Sends a request and returns a mailbox, timing out after duration has passed
pub async fn mail_timeout(
&mut self,
req: impl Into<Request<T>>,
duration: impl Into<Option<Duration>>,
) -> io::Result<Mailbox<Response<U>>> {
match duration.into() {
Some(duration) => tokio::time::timeout(duration, self.mail(req))
.await
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
.and_then(convert::identity),
None => self.mail(req).await,
}
}
/// Sends a request and waits for a response, failing if unable to send a request or if
/// the session's receiving line to the remote server has already been severed
pub async fn send(&mut self, req: impl Into<Request<T>>) -> io::Result<Response<U>> {
// Send mail and get back a mailbox
let mut mailbox = self.mail(req).await?;
// Wait for first response, and then drop the mailbox
mailbox
.next()
.await
.ok_or_else(|| io::Error::from(io::ErrorKind::ConnectionAborted))
}
/// Sends a request and waits for a response, timing out after duration has passed
pub async fn send_timeout(
&mut self,
req: impl Into<Request<T>>,
duration: impl Into<Option<Duration>>,
) -> io::Result<Response<U>> {
match duration.into() {
Some(duration) => tokio::time::timeout(duration, self.send(req))
.await
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
.and_then(convert::identity),
None => self.send(req).await,
}
}
/// Sends a request without waiting for a response; this method is able to be used even
/// if the session's receiving line to the remote server has been severed
pub async fn fire(&mut self, req: impl Into<Request<T>>) -> io::Result<()> {
self.inner.fire(req.into().to_untyped_request()?).await
}
/// Sends a request without waiting for a response, timing out after duration has passed
pub async fn fire_timeout(
&mut self,
req: impl Into<Request<T>>,
duration: impl Into<Option<Duration>>,
) -> io::Result<()> {
match duration.into() {
Some(duration) => tokio::time::timeout(duration, self.fire(req))
.await
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
.and_then(convert::identity),
None => self.fire(req).await,
}
}
}
fn map_to_typed_mailbox<T: Send + DeserializeOwned + 'static>(
mailbox: Mailbox<UntypedResponse<'static>>,
) -> Mailbox<Response<T>> {
mailbox.map_opt(|res| match res.to_typed_response() {
Ok(res) => Some(res),
Err(x) => {
if log::log_enabled!(Level::Trace) {
trace!(
"Invalid response payload: {}",
String::from_utf8_lossy(&res.payload)
);
}
error!(
"Unable to parse response payload into {}: {x}",
std::any::type_name::<T>()
);
None
}
})
}
/// Represents a sender of requests tied to a session, holding onto a weak reference of
/// mailboxes to relay responses, meaning that once the [`Client`] is closed or dropped,
/// any sent request will no longer be able to receive responses.
///
/// In contrast to [`Channel`], this implementation is untyped, meaning that the payload of
/// requests and responses are not validated.
///
/// [`Client`]: crate::client::Client
#[derive(Debug)]
pub struct UntypedChannel {
/// Used to send requests to a server
pub(crate) tx: mpsc::Sender<Request<T>>,
pub(crate) tx: mpsc::Sender<UntypedRequest<'static>>,
/// Collection of mailboxes for receiving responses to requests
pub(crate) post_office: Weak<PostOffice<Response<U>>>,
pub(crate) post_office: Weak<PostOffice<UntypedResponse<'static>>>,
}
// NOTE: Implemented manually to avoid needing clone to be defined on generic types
impl<T, U> Clone for Channel<T, U>
where
T: Send + Sync,
U: Send + Sync,
{
impl Clone for UntypedChannel {
fn clone(&self) -> Self {
Self {
tx: self.tx.clone(),
@ -37,31 +194,66 @@ where
}
}
impl<T, U> Channel<T, U>
where
T: Send + Sync,
U: Send + Sync + 'static,
{
impl UntypedChannel {
/// Returns true if no more requests can be transferred
pub fn is_closed(&self) -> bool {
self.tx.is_closed()
}
/// Consumes this channel, returning a typed variant
pub fn into_typed_channel<T, U>(self) -> Channel<T, U> {
Channel {
inner: self,
_request: PhantomData,
_response: PhantomData,
}
}
/// Assigns a default mailbox for any response received that does not match another mailbox.
pub async fn assign_default_mailbox(
&self,
buffer: usize,
) -> io::Result<Mailbox<UntypedResponse<'static>>> {
match Weak::upgrade(&self.post_office) {
Some(post_office) => Ok(post_office.assign_default_mailbox(buffer).await),
None => Err(io::Error::new(
io::ErrorKind::NotConnected,
"Channel's post office is no longer available",
)),
}
}
/// Removes the default mailbox used for unmatched responses such that any response without a
/// matching mailbox will be dropped.
pub async fn remove_default_mailbox(&self) -> io::Result<()> {
match Weak::upgrade(&self.post_office) {
Some(post_office) => {
post_office.remove_default_mailbox().await;
Ok(())
}
None => Err(io::Error::new(
io::ErrorKind::NotConnected,
"Channel's post office is no longer available",
)),
}
}
/// Sends a request and returns a mailbox that can receive one or more responses, failing if
/// unable to send a request or if the session's receiving line to the remote server has
/// already been severed
pub async fn mail(&mut self, req: impl Into<Request<T>>) -> io::Result<Mailbox<Response<U>>> {
let req = req.into();
pub async fn mail(
&mut self,
req: UntypedRequest<'_>,
) -> io::Result<Mailbox<UntypedResponse<'static>>> {
// First, create a mailbox using the request's id
let mailbox = Weak::upgrade(&self.post_office)
.ok_or_else(|| {
io::Error::new(
io::ErrorKind::NotConnected,
"Session's post office is no longer available",
"Channel's post office is no longer available",
)
})?
.make_mailbox(req.id.clone(), CHANNEL_MAILBOX_CAPACITY)
.make_mailbox(req.id.clone().into_owned(), CHANNEL_MAILBOX_CAPACITY)
.await;
// Second, send the request
@ -74,9 +266,9 @@ where
/// Sends a request and returns a mailbox, timing out after duration has passed
pub async fn mail_timeout(
&mut self,
req: impl Into<Request<T>>,
req: UntypedRequest<'_>,
duration: impl Into<Option<Duration>>,
) -> io::Result<Mailbox<Response<U>>> {
) -> io::Result<Mailbox<UntypedResponse<'static>>> {
match duration.into() {
Some(duration) => tokio::time::timeout(duration, self.mail(req))
.await
@ -88,7 +280,7 @@ where
/// Sends a request and waits for a response, failing if unable to send a request or if
/// the session's receiving line to the remote server has already been severed
pub async fn send(&mut self, req: impl Into<Request<T>>) -> io::Result<Response<U>> {
pub async fn send(&mut self, req: UntypedRequest<'_>) -> io::Result<UntypedResponse<'static>> {
// Send mail and get back a mailbox
let mut mailbox = self.mail(req).await?;
@ -102,9 +294,9 @@ where
/// Sends a request and waits for a response, timing out after duration has passed
pub async fn send_timeout(
&mut self,
req: impl Into<Request<T>>,
req: UntypedRequest<'_>,
duration: impl Into<Option<Duration>>,
) -> io::Result<Response<U>> {
) -> io::Result<UntypedResponse<'static>> {
match duration.into() {
Some(duration) => tokio::time::timeout(duration, self.send(req))
.await
@ -116,9 +308,9 @@ where
/// Sends a request without waiting for a response; this method is able to be used even
/// if the session's receiving line to the remote server has been severed
pub async fn fire(&mut self, req: impl Into<Request<T>>) -> io::Result<()> {
pub async fn fire(&mut self, req: UntypedRequest<'_>) -> io::Result<()> {
self.tx
.send(req.into())
.send(req.into_owned())
.await
.map_err(|x| io::Error::new(io::ErrorKind::BrokenPipe, x.to_string()))
}
@ -126,7 +318,7 @@ where
/// Sends a request without waiting for a response, timing out after duration has passed
pub async fn fire_timeout(
&mut self,
req: impl Into<Request<T>>,
req: UntypedRequest<'_>,
duration: impl Into<Option<Duration>>,
) -> io::Result<()> {
match duration.into() {
@ -142,95 +334,227 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::{Client, FramedTransport, TypedAsyncRead, TypedAsyncWrite};
use std::time::Duration;
type TestClient = Client<u8, u8>;
mod typed {
use super::*;
use std::sync::Arc;
use std::time::Duration;
use test_log::test;
type TestChannel = Channel<u8, u8>;
type Setup = (
TestChannel,
mpsc::Receiver<UntypedRequest<'static>>,
Arc<PostOffice<UntypedResponse<'static>>>,
);
fn setup(buffer: usize) -> Setup {
let post_office = Arc::new(PostOffice::default());
let (tx, rx) = mpsc::channel(buffer);
let channel = {
let post_office = Arc::downgrade(&post_office);
UntypedChannel { tx, post_office }
};
(channel.into_typed_channel(), rx, post_office)
}
#[tokio::test]
async fn mail_should_return_mailbox_that_receives_responses_until_transport_closes() {
let (t1, mut t2) = FramedTransport::make_test_pair();
let session: TestClient = Client::from_framed_transport(t1).unwrap();
let mut channel = session.clone_channel();
#[test(tokio::test)]
async fn mail_should_return_mailbox_that_receives_responses_until_post_office_drops_it() {
let (mut channel, _server, post_office) = setup(100);
let req = Request::new(0);
let res = Response::new(req.id.clone(), 1);
let mut mailbox = channel.mail(req).await.unwrap();
// Get first response
match tokio::join!(mailbox.next(), t2.write(res.clone())) {
(Some(actual), _) => assert_eq!(actual, res),
// Send and receive first response
assert!(
post_office
.deliver_untyped_response(res.to_untyped_response().unwrap().into_owned())
.await,
"Failed to deliver: {res:?}"
);
assert_eq!(mailbox.next().await, Some(res.clone()));
// Send and receive second response
assert!(
post_office
.deliver_untyped_response(res.to_untyped_response().unwrap().into_owned())
.await,
"Failed to deliver: {res:?}"
);
assert_eq!(mailbox.next().await, Some(res.clone()));
// Trigger the mailbox to wait BEFORE closing our mailbox to ensure that
// we don't get stuck if the mailbox was already waiting
let next_task = tokio::spawn(async move { mailbox.next().await });
tokio::task::yield_now().await;
// Close our specific mailbox
post_office.cancel(&res.origin_id).await;
match next_task.await {
Ok(None) => {}
x => panic!("Unexpected response: {:?}", x),
}
}
#[test(tokio::test)]
async fn send_should_wait_until_response_received() {
let (mut channel, _server, post_office) = setup(100);
let req = Request::new(0);
let res = Response::new(req.id.clone(), 1);
let (actual, _) = tokio::join!(
channel.send(req),
post_office
.deliver_untyped_response(res.to_untyped_response().unwrap().into_owned())
);
match actual {
Ok(actual) => assert_eq!(actual, res),
x => panic!("Unexpected response: {:?}", x),
}
}
#[test(tokio::test)]
async fn send_timeout_should_fail_if_response_not_received_in_time() {
let (mut channel, mut server, _post_office) = setup(100);
let req = Request::new(0);
match channel.send_timeout(req, Duration::from_millis(30)).await {
Err(x) => assert_eq!(x.kind(), io::ErrorKind::TimedOut),
x => panic!("Unexpected response: {:?}", x),
}
// Get second response
match tokio::join!(mailbox.next(), t2.write(res.clone())) {
(Some(actual), _) => assert_eq!(actual, res),
let _frame = server.recv().await.unwrap();
}
#[test(tokio::test)]
async fn fire_should_send_request_and_not_wait_for_response() {
let (mut channel, mut server, _post_office) = setup(100);
let req = Request::new(0);
match channel.fire(req).await {
Ok(_) => {}
x => panic!("Unexpected response: {:?}", x),
}
// Trigger the mailbox to wait BEFORE closing our transport to ensure that
let _frame = server.recv().await.unwrap();
}
}
mod untyped {
use super::*;
use std::sync::Arc;
use std::time::Duration;
use test_log::test;
type TestChannel = UntypedChannel;
type Setup = (
TestChannel,
mpsc::Receiver<UntypedRequest<'static>>,
Arc<PostOffice<UntypedResponse<'static>>>,
);
fn setup(buffer: usize) -> Setup {
let post_office = Arc::new(PostOffice::default());
let (tx, rx) = mpsc::channel(buffer);
let channel = {
let post_office = Arc::downgrade(&post_office);
TestChannel { tx, post_office }
};
(channel, rx, post_office)
}
#[test(tokio::test)]
async fn mail_should_return_mailbox_that_receives_responses_until_post_office_drops_it() {
let (mut channel, _server, post_office) = setup(100);
let req = Request::new(0).to_untyped_request().unwrap().into_owned();
let res = Response::new(req.id.clone().into_owned(), 1)
.to_untyped_response()
.unwrap()
.into_owned();
let mut mailbox = channel.mail(req).await.unwrap();
// Send and receive first response
assert!(
post_office.deliver_untyped_response(res.clone()).await,
"Failed to deliver: {res:?}"
);
assert_eq!(mailbox.next().await, Some(res.clone()));
// Send and receive second response
assert!(
post_office.deliver_untyped_response(res.clone()).await,
"Failed to deliver: {res:?}"
);
assert_eq!(mailbox.next().await, Some(res.clone()));
// Trigger the mailbox to wait BEFORE closing our mailbox to ensure that
// we don't get stuck if the mailbox was already waiting
let next_task = tokio::spawn(async move { mailbox.next().await });
tokio::task::yield_now().await;
drop(t2);
// Close our specific mailbox
post_office
.cancel(&res.origin_id.clone().into_owned())
.await;
match next_task.await {
Ok(None) => {}
x => panic!("Unexpected response: {:?}", x),
}
}
#[tokio::test]
#[test(tokio::test)]
async fn send_should_wait_until_response_received() {
let (t1, mut t2) = FramedTransport::make_test_pair();
let session: TestClient = Client::from_framed_transport(t1).unwrap();
let mut channel = session.clone_channel();
let (mut channel, _server, post_office) = setup(100);
let req = Request::new(0);
let res = Response::new(req.id.clone(), 1);
let req = Request::new(0).to_untyped_request().unwrap().into_owned();
let res = Response::new(req.id.clone().into_owned(), 1)
.to_untyped_response()
.unwrap()
.into_owned();
let (actual, _) = tokio::join!(channel.send(req), t2.write(res.clone()));
let (actual, _) = tokio::join!(
channel.send(req),
post_office.deliver_untyped_response(res.clone())
);
match actual {
Ok(actual) => assert_eq!(actual, res),
x => panic!("Unexpected response: {:?}", x),
}
}
#[tokio::test]
#[test(tokio::test)]
async fn send_timeout_should_fail_if_response_not_received_in_time() {
let (t1, mut t2) = FramedTransport::make_test_pair();
let session: TestClient = Client::from_framed_transport(t1).unwrap();
let mut channel = session.clone_channel();
let (mut channel, mut server, _post_office) = setup(100);
let req = Request::new(0);
let req = Request::new(0).to_untyped_request().unwrap().into_owned();
match channel.send_timeout(req, Duration::from_millis(30)).await {
Err(x) => assert_eq!(x.kind(), io::ErrorKind::TimedOut),
x => panic!("Unexpected response: {:?}", x),
}
let _req = TypedAsyncRead::<Request<u8>>::read(&mut t2)
.await
.unwrap()
.unwrap();
let _frame = server.recv().await.unwrap();
}
#[tokio::test]
#[test(tokio::test)]
async fn fire_should_send_request_and_not_wait_for_response() {
let (t1, mut t2) = FramedTransport::make_test_pair();
let session: TestClient = Client::from_framed_transport(t1).unwrap();
let mut channel = session.clone_channel();
let (mut channel, mut server, _post_office) = setup(100);
let req = Request::new(0);
let req = Request::new(0).to_untyped_request().unwrap().into_owned();
match channel.fire(req).await {
Ok(_) => {}
x => panic!("Unexpected response: {:?}", x),
}
let _req = TypedAsyncRead::<Request<u8>>::read(&mut t2)
.await
.unwrap()
.unwrap();
let _frame = server.recv().await.unwrap();
}
}
}

@ -1,4 +1,5 @@
use crate::{Id, Response};
use crate::common::{Id, Response, UntypedResponse};
use async_trait::async_trait;
use std::{
collections::HashMap,
sync::{Arc, Weak},
@ -6,13 +7,14 @@ use std::{
};
use tokio::{
io,
sync::{mpsc, Mutex},
sync::{mpsc, Mutex, RwLock},
time,
};
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct PostOffice<T> {
mailboxes: Arc<Mutex<HashMap<Id, mpsc::Sender<T>>>>,
default_box: Arc<RwLock<Option<mpsc::Sender<T>>>>,
}
impl<T> Default for PostOffice<T>
@ -51,7 +53,10 @@ where
}
});
Self { mailboxes }
Self {
mailboxes,
default_box: Arc::new(RwLock::new(None)),
}
}
/// Creates a new mailbox using the given id and buffer size for maximum values that
@ -60,7 +65,10 @@ where
let (tx, rx) = mpsc::channel(buffer);
self.mailboxes.lock().await.insert(id.clone(), tx);
Mailbox { id, rx }
Mailbox {
id,
rx: Box::new(rx),
}
}
/// Delivers some value to appropriate mailbox, returning false if no mailbox is found
@ -75,10 +83,54 @@ where
}
success
} else if let Some(tx) = self.default_box.read().await.as_ref() {
tx.send(value).await.is_ok()
} else {
false
}
}
/// Creates a new default mailbox that will be used whenever no mailbox is found to deliver
/// mail. This will replace any existing default mailbox.
pub async fn assign_default_mailbox(&self, buffer: usize) -> Mailbox<T> {
let (tx, rx) = mpsc::channel(buffer);
*self.default_box.write().await = Some(tx);
Mailbox {
id: "".to_string(),
rx: Box::new(rx),
}
}
/// Removes the default mailbox such that any mail without a matching mailbox will be dropped
/// instead of being delivered to a default mailbox.
pub async fn remove_default_mailbox(&self) {
*self.default_box.write().await = None;
}
/// Returns true if the post office is using a default mailbox for all mail that does not map
/// to another mailbox.
pub async fn has_default_mailbox(&self) -> bool {
self.default_box.read().await.is_some()
}
/// Cancels delivery to the mailbox with the specified `id`.
pub async fn cancel(&self, id: &Id) {
self.mailboxes.lock().await.remove(id);
}
/// Cancels delivery to the mailboxes with the specified `id`s.
pub async fn cancel_many(&self, ids: impl Iterator<Item = &Id>) {
let mut lock = self.mailboxes.lock().await;
for id in ids {
lock.remove(id);
}
}
/// Cancels delivery to all mailboxes.
pub async fn cancel_all(&self) {
self.mailboxes.lock().await.clear();
}
}
impl<T> PostOffice<Response<T>>
@ -92,13 +144,120 @@ where
}
}
impl PostOffice<UntypedResponse<'static>> {
/// Delivers some response to appropriate mailbox, returning false if no mailbox is found
/// for the response's origin or if the mailbox is no longer receiving values
pub async fn deliver_untyped_response(&self, res: UntypedResponse<'static>) -> bool {
self.deliver(&res.origin_id.clone().into_owned(), res).await
}
}
/// Error encountered when invoking [`try_recv`] for [`MailboxReceiver`].
pub enum MailboxTryNextError {
Empty,
Closed,
}
#[async_trait]
trait MailboxReceiver: Send + Sync {
type Output;
fn try_recv(&mut self) -> Result<Self::Output, MailboxTryNextError>;
async fn recv(&mut self) -> Option<Self::Output>;
fn close(&mut self);
}
#[async_trait]
impl<T: Send> MailboxReceiver for mpsc::Receiver<T> {
type Output = T;
fn try_recv(&mut self) -> Result<Self::Output, MailboxTryNextError> {
match mpsc::Receiver::try_recv(self) {
Ok(x) => Ok(x),
Err(mpsc::error::TryRecvError::Empty) => Err(MailboxTryNextError::Empty),
Err(mpsc::error::TryRecvError::Disconnected) => Err(MailboxTryNextError::Closed),
}
}
async fn recv(&mut self) -> Option<Self::Output> {
mpsc::Receiver::recv(self).await
}
fn close(&mut self) {
mpsc::Receiver::close(self)
}
}
struct MappedMailboxReceiver<T, U> {
rx: Box<dyn MailboxReceiver<Output = T>>,
f: Box<dyn Fn(T) -> U + Send + Sync>,
}
#[async_trait]
impl<T: Send, U: Send> MailboxReceiver for MappedMailboxReceiver<T, U> {
type Output = U;
fn try_recv(&mut self) -> Result<Self::Output, MailboxTryNextError> {
match self.rx.try_recv() {
Ok(x) => Ok((self.f)(x)),
Err(x) => Err(x),
}
}
async fn recv(&mut self) -> Option<Self::Output> {
let value = self.rx.recv().await?;
Some((self.f)(value))
}
fn close(&mut self) {
self.rx.close()
}
}
struct MappedOptMailboxReceiver<T, U> {
rx: Box<dyn MailboxReceiver<Output = T>>,
f: Box<dyn Fn(T) -> Option<U> + Send + Sync>,
}
#[async_trait]
impl<T: Send, U: Send> MailboxReceiver for MappedOptMailboxReceiver<T, U> {
type Output = U;
fn try_recv(&mut self) -> Result<Self::Output, MailboxTryNextError> {
match self.rx.try_recv() {
Ok(x) => match (self.f)(x) {
Some(x) => Ok(x),
None => Err(MailboxTryNextError::Empty),
},
Err(x) => Err(x),
}
}
async fn recv(&mut self) -> Option<Self::Output> {
// Continually receive a new value and convert it to Option<U>
// until Option<U> == Some(U) or we receive None from our inner receiver
loop {
let value = self.rx.recv().await?;
if let Some(x) = (self.f)(value) {
return Some(x);
}
}
}
fn close(&mut self) {
self.rx.close()
}
}
/// Represents a destination for responses
pub struct Mailbox<T> {
/// Represents id associated with the mailbox
id: Id,
/// Underlying mailbox storage
rx: mpsc::Receiver<T>,
rx: Box<dyn MailboxReceiver<Output = T>>,
}
impl<T> Mailbox<T> {
@ -107,6 +266,11 @@ impl<T> Mailbox<T> {
&self.id
}
/// Tries to receive the next value in mailbox without blocking or waiting async
pub fn try_next(&mut self) -> Result<T, MailboxTryNextError> {
self.rx.try_recv()
}
/// Receives next value in mailbox
pub async fn next(&mut self) -> Option<T> {
self.rx.recv().await
@ -126,3 +290,31 @@ impl<T> Mailbox<T> {
self.rx.close()
}
}
impl<T: Send + 'static> Mailbox<T> {
/// Maps the results of each mailbox value into a new type `U`
pub fn map<U: Send + 'static>(self, f: impl Fn(T) -> U + Send + Sync + 'static) -> Mailbox<U> {
Mailbox {
id: self.id,
rx: Box::new(MappedMailboxReceiver {
rx: self.rx,
f: Box::new(f),
}),
}
}
/// Maps the results of each mailbox value into a new type `U` by returning an `Option<U>`
/// where the option is `None` in the case that `T` cannot be converted into `U`
pub fn map_opt<U: Send + 'static>(
self,
f: impl Fn(T) -> Option<U> + Send + Sync + 'static,
) -> Mailbox<U> {
Mailbox {
id: self.id,
rx: Box::new(MappedOptMailboxReceiver {
rx: self.rx,
f: Box::new(f),
}),
}
}
}

@ -1,49 +0,0 @@
use crate::{Client, Codec, FramedTransport, TcpTransport};
use async_trait::async_trait;
use serde::{de::DeserializeOwned, Serialize};
use std::{convert, net::SocketAddr};
use tokio::{io, time::Duration};
#[async_trait]
pub trait TcpClientExt<T, U>
where
T: Serialize + Send + Sync,
U: DeserializeOwned + Send + Sync,
{
/// Connect to a remote TCP server using the provided information
async fn connect<C>(addr: SocketAddr, codec: C) -> io::Result<Client<T, U>>
where
C: Codec + Send + 'static;
/// Connect to a remote TCP server, timing out after duration has passed
async fn connect_timeout<C>(
addr: SocketAddr,
codec: C,
duration: Duration,
) -> io::Result<Client<T, U>>
where
C: Codec + Send + 'static,
{
tokio::time::timeout(duration, Self::connect(addr, codec))
.await
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
.and_then(convert::identity)
}
}
#[async_trait]
impl<T, U> TcpClientExt<T, U> for Client<T, U>
where
T: Send + Sync + Serialize + 'static,
U: Send + Sync + DeserializeOwned + 'static,
{
/// Connect to a remote TCP server using the provided information
async fn connect<C>(addr: SocketAddr, codec: C) -> io::Result<Client<T, U>>
where
C: Codec + Send + 'static,
{
let transport = TcpTransport::connect(addr).await?;
let transport = FramedTransport::new(transport, codec);
Self::from_framed_transport(transport)
}
}

@ -1,54 +0,0 @@
use crate::{Client, Codec, FramedTransport, IntoSplit, UnixSocketTransport};
use async_trait::async_trait;
use serde::{de::DeserializeOwned, Serialize};
use std::{convert, path::Path};
use tokio::{io, time::Duration};
#[async_trait]
pub trait UnixSocketClientExt<T, U>
where
T: Serialize + Send + Sync,
U: DeserializeOwned + Send + Sync,
{
/// Connect to a proxy unix socket
async fn connect<P, C>(path: P, codec: C) -> io::Result<Client<T, U>>
where
P: AsRef<Path> + Send,
C: Codec + Send + 'static;
/// Connect to a proxy unix socket, timing out after duration has passed
async fn connect_timeout<P, C>(
path: P,
codec: C,
duration: Duration,
) -> io::Result<Client<T, U>>
where
P: AsRef<Path> + Send,
C: Codec + Send + 'static,
{
tokio::time::timeout(duration, Self::connect(path, codec))
.await
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
.and_then(convert::identity)
}
}
#[async_trait]
impl<T, U> UnixSocketClientExt<T, U> for Client<T, U>
where
T: Send + Sync + Serialize + 'static,
U: Send + Sync + DeserializeOwned + 'static,
{
/// Connect to a proxy unix socket
async fn connect<P, C>(path: P, codec: C) -> io::Result<Client<T, U>>
where
P: AsRef<Path> + Send,
C: Codec + Send + 'static,
{
let p = path.as_ref();
let transport = UnixSocketTransport::connect(p).await?;
let transport = FramedTransport::new(transport, codec);
let (writer, reader) = transport.into_split();
Ok(Client::new(writer, reader)?)
}
}

@ -1,86 +0,0 @@
use crate::{Client, Codec, FramedTransport, IntoSplit, WindowsPipeTransport};
use async_trait::async_trait;
use serde::{de::DeserializeOwned, Serialize};
use std::{
convert,
ffi::{OsStr, OsString},
};
use tokio::{io, time::Duration};
#[async_trait]
pub trait WindowsPipeClientExt<T, U>
where
T: Serialize + Send + Sync,
U: DeserializeOwned + Send + Sync,
{
/// Connect to a server listening on a Windows pipe at the specified address
/// using the given codec
async fn connect<A, C>(addr: A, codec: C) -> io::Result<Client<T, U>>
where
A: AsRef<OsStr> + Send,
C: Codec + Send + 'static;
/// Connect to a server listening on a Windows pipe at the specified address
/// via `\\.\pipe\{name}` using the given codec
async fn connect_local<N, C>(name: N, codec: C) -> io::Result<Client<T, U>>
where
N: AsRef<OsStr> + Send,
C: Codec + Send + 'static,
{
let mut addr = OsString::from(r"\\.\pipe\");
addr.push(name.as_ref());
Self::connect(addr, codec).await
}
/// Connect to a server listening on a Windows pipe at the specified address
/// using the given codec, timing out after duration has passed
async fn connect_timeout<A, C>(
addr: A,
codec: C,
duration: Duration,
) -> io::Result<Client<T, U>>
where
A: AsRef<OsStr> + Send,
C: Codec + Send + 'static,
{
tokio::time::timeout(duration, Self::connect(addr, codec))
.await
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
.and_then(convert::identity)
}
/// Connect to a server listening on a Windows pipe at the specified address
/// via `\\.\pipe\{name}` using the given codec, timing out after duration has passed
async fn connect_local_timeout<N, C>(
name: N,
codec: C,
duration: Duration,
) -> io::Result<Client<T, U>>
where
N: AsRef<OsStr> + Send,
C: Codec + Send + 'static,
{
let mut addr = OsString::from(r"\\.\pipe\");
addr.push(name.as_ref());
Self::connect_timeout(addr, codec, duration).await
}
}
#[async_trait]
impl<T, U> WindowsPipeClientExt<T, U> for Client<T, U>
where
T: Send + Sync + Serialize + 'static,
U: Send + Sync + DeserializeOwned + 'static,
{
async fn connect<A, C>(addr: A, codec: C) -> io::Result<Client<T, U>>
where
A: AsRef<OsStr> + Send,
C: Codec + Send + 'static,
{
let a = addr.as_ref();
let transport = WindowsPipeTransport::connect(a).await?;
let transport = FramedTransport::new(transport, codec);
let (writer, reader) = transport.into_split();
Ok(Client::new(writer, reader)?)
}
}

@ -0,0 +1,208 @@
use super::Reconnectable;
use std::io;
use std::time::Duration;
/// Represents the strategy to apply when attempting to reconnect the client to the server.
#[derive(Clone, Debug)]
pub enum ReconnectStrategy {
/// A retry strategy that will fail immediately if a reconnect is attempted.
Fail,
/// A retry strategy driven by exponential back-off.
ExponentialBackoff {
/// Represents the initial time to wait between reconnect attempts.
base: Duration,
/// Factor to use when modifying the retry time, used as a multiplier.
factor: f64,
/// Represents the maximum duration to wait between attempts. None indicates no limit.
max_duration: Option<Duration>,
/// Represents the maximum attempts to retry before failing. None indicates no limit.
max_retries: Option<usize>,
/// Represents the maximum time to wait for a reconnect attempt. None indicates no limit.
timeout: Option<Duration>,
},
/// A retry strategy driven by the fibonacci series.
FibonacciBackoff {
/// Represents the initial time to wait between reconnect attempts.
base: Duration,
/// Represents the maximum duration to wait between attempts. None indicates no limit.
max_duration: Option<Duration>,
/// Represents the maximum attempts to retry before failing. None indicates no limit.
max_retries: Option<usize>,
/// Represents the maximum time to wait for a reconnect attempt. None indicates no limit.
timeout: Option<Duration>,
},
/// A retry strategy driven by a fixed interval.
FixedInterval {
/// Represents the time between reconnect attempts.
interval: Duration,
/// Represents the maximum attempts to retry before failing. None indicates no limit.
max_retries: Option<usize>,
/// Represents the maximum time to wait for a reconnect attempt. None indicates no limit.
timeout: Option<Duration>,
},
}
impl Default for ReconnectStrategy {
/// Creates a reconnect strategy that will immediately fail.
fn default() -> Self {
Self::Fail
}
}
impl ReconnectStrategy {
pub async fn reconnect<T: Reconnectable>(&mut self, reconnectable: &mut T) -> io::Result<()> {
// If our strategy is to immediately fail, do so
if self.is_fail() {
return Err(io::Error::from(io::ErrorKind::ConnectionAborted));
}
// Keep track of last sleep length for use in adjustment
let mut previous_sleep = None;
let mut current_sleep = self.initial_sleep_duration();
// Keep track of remaining retries
let mut retries_remaining = self.max_retries();
// Get timeout if strategy will employ one
let timeout = self.timeout();
// Get maximum allowed duration between attempts
let max_duration = self.max_duration();
// Continue trying to reconnect while we have more tries remaining, otherwise
// we will return the last error encountered
let mut result = Ok(());
while retries_remaining.is_none() || retries_remaining > Some(0) {
// Perform reconnect attempt
result = match timeout {
Some(timeout) => {
match tokio::time::timeout(timeout, reconnectable.reconnect()).await {
Ok(x) => x,
Err(x) => Err(x.into()),
}
}
None => reconnectable.reconnect().await,
};
// If reconnect was successful, we're done and we can exit
if result.is_ok() {
return Ok(());
}
// Decrement remaining retries if we have a limit
if let Some(remaining) = retries_remaining.as_mut() {
if *remaining > 0 {
*remaining -= 1;
}
}
// Sleep before making next attempt
tokio::time::sleep(current_sleep).await;
// Update our sleep duration
let next_sleep = self.adjust_sleep(previous_sleep, current_sleep);
previous_sleep = Some(current_sleep);
current_sleep = if let Some(duration) = max_duration {
std::cmp::min(next_sleep, duration)
} else {
next_sleep
};
}
result
}
/// Returns true if this strategy is the fail variant.
pub fn is_fail(&self) -> bool {
matches!(self, Self::Fail)
}
/// Returns true if this strategy is the exponential backoff variant.
pub fn is_exponential_backoff(&self) -> bool {
matches!(self, Self::ExponentialBackoff { .. })
}
/// Returns true if this strategy is the fibonacci backoff variant.
pub fn is_fibonacci_backoff(&self) -> bool {
matches!(self, Self::FibonacciBackoff { .. })
}
/// Returns true if this strategy is the fixed interval variant.
pub fn is_fixed_interval(&self) -> bool {
matches!(self, Self::FixedInterval { .. })
}
/// Returns the maximum duration between reconnect attempts, or None if there is no limit.
pub fn max_duration(&self) -> Option<Duration> {
match self {
ReconnectStrategy::Fail => None,
ReconnectStrategy::ExponentialBackoff { max_duration, .. } => *max_duration,
ReconnectStrategy::FibonacciBackoff { max_duration, .. } => *max_duration,
ReconnectStrategy::FixedInterval { .. } => None,
}
}
/// Returns the maximum reconnect attempts the strategy will perform, or None if will attempt
/// forever.
pub fn max_retries(&self) -> Option<usize> {
match self {
ReconnectStrategy::Fail => None,
ReconnectStrategy::ExponentialBackoff { max_retries, .. } => *max_retries,
ReconnectStrategy::FibonacciBackoff { max_retries, .. } => *max_retries,
ReconnectStrategy::FixedInterval { max_retries, .. } => *max_retries,
}
}
/// Returns the timeout per reconnect attempt that is associated with the strategy.
pub fn timeout(&self) -> Option<Duration> {
match self {
ReconnectStrategy::Fail => None,
ReconnectStrategy::ExponentialBackoff { timeout, .. } => *timeout,
ReconnectStrategy::FibonacciBackoff { timeout, .. } => *timeout,
ReconnectStrategy::FixedInterval { timeout, .. } => *timeout,
}
}
/// Returns the initial duration to sleep.
fn initial_sleep_duration(&self) -> Duration {
match self {
ReconnectStrategy::Fail => Duration::new(0, 0),
ReconnectStrategy::ExponentialBackoff { base, .. } => *base,
ReconnectStrategy::FibonacciBackoff { base, .. } => *base,
ReconnectStrategy::FixedInterval { interval, .. } => *interval,
}
}
/// Adjusts next sleep duration based on the strategy.
fn adjust_sleep(&self, prev: Option<Duration>, curr: Duration) -> Duration {
match self {
ReconnectStrategy::Fail => Duration::new(0, 0),
ReconnectStrategy::ExponentialBackoff { factor, .. } => {
let next_millis = (curr.as_millis() as f64) * factor;
Duration::from_millis(if next_millis > (std::u64::MAX as f64) {
std::u64::MAX
} else {
next_millis as u64
})
}
ReconnectStrategy::FibonacciBackoff { .. } => {
let prev = prev.unwrap_or_else(|| Duration::new(0, 0));
prev.checked_add(curr).unwrap_or(Duration::MAX)
}
ReconnectStrategy::FixedInterval { .. } => curr,
}
}
}

@ -0,0 +1,36 @@
use async_trait::async_trait;
use dyn_clone::DynClone;
use std::io;
use tokio::sync::{mpsc, oneshot};
/// Interface representing functionality to shut down an active client.
#[async_trait]
pub trait Shutdown: DynClone + Send + Sync {
/// Attempts to shutdown the client.
async fn shutdown(&self) -> io::Result<()>;
}
#[async_trait]
impl Shutdown for mpsc::Sender<oneshot::Sender<io::Result<()>>> {
async fn shutdown(&self) -> io::Result<()> {
let (tx, rx) = oneshot::channel();
match self.send(tx).await {
Ok(_) => match rx.await {
Ok(x) => x,
Err(_) => Err(already_shutdown()),
},
Err(_) => Err(already_shutdown()),
}
}
}
#[inline]
fn already_shutdown() -> io::Error {
io::Error::new(io::ErrorKind::Other, "Client already shutdown")
}
impl Clone for Box<dyn Shutdown> {
fn clone(&self) -> Self {
dyn_clone::clone_box(&**self)
}
}

@ -1,38 +0,0 @@
use bytes::BytesMut;
use std::io;
use tokio_util::codec::{Decoder, Encoder};
/// Represents abstraction of a codec that implements specific encoder and decoder for distant
pub trait Codec:
for<'a> Encoder<&'a [u8], Error = io::Error> + Decoder<Item = Vec<u8>, Error = io::Error> + Clone
{
fn encode(&mut self, item: &[u8], dst: &mut BytesMut) -> io::Result<()>;
fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<Vec<u8>>>;
}
macro_rules! impl_traits_for_codec {
($type:ident) => {
impl<'a> tokio_util::codec::Encoder<&'a [u8]> for $type {
type Error = io::Error;
fn encode(&mut self, item: &'a [u8], dst: &mut BytesMut) -> Result<(), Self::Error> {
Codec::encode(self, item, dst)
}
}
impl tokio_util::codec::Decoder for $type {
type Item = Vec<u8>;
type Error = io::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
Codec::decode(self, src)
}
}
};
}
mod plain;
pub use plain::PlainCodec;
mod xchacha20poly1305;
pub use xchacha20poly1305::XChaCha20Poly1305Codec;

@ -1,193 +0,0 @@
use crate::Codec;
use bytes::{Buf, BufMut, BytesMut};
use std::convert::TryInto;
use tokio::io;
/// Total bytes to use as the len field denoting a frame's size
const LEN_SIZE: usize = 8;
/// Represents a codec that just ships messages back and forth with no encryption or authentication
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
pub struct PlainCodec;
impl_traits_for_codec!(PlainCodec);
impl PlainCodec {
pub fn new() -> Self {
Self::default()
}
}
impl Codec for PlainCodec {
fn encode(&mut self, item: &[u8], dst: &mut BytesMut) -> io::Result<()> {
// Validate that we can fit the message plus nonce +
if item.is_empty() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Empty item provided",
));
}
dst.reserve(8 + item.len());
// Add data in form of {LEN}{ITEM}
dst.put_u64((item.len()) as u64);
dst.put_slice(item);
Ok(())
}
fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<Vec<u8>>> {
// First, check if we have more data than just our frame's message length
if src.len() <= LEN_SIZE {
return Ok(None);
}
// Second, retrieve total size of our frame's message
let msg_len = u64::from_be_bytes(src[..LEN_SIZE].try_into().unwrap()) as usize;
if msg_len == 0 {
// Ensure we advance to remove the frame
src.advance(LEN_SIZE);
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Frame's msg cannot have length of 0",
));
}
// Third, check if we have all data for our frame; if not, exit early
if src.len() < msg_len + LEN_SIZE {
return Ok(None);
}
// Fourth, get and return our item
let item = src[LEN_SIZE..(LEN_SIZE + msg_len)].to_vec();
// Fifth, advance so frame is no longer kept around
src.advance(LEN_SIZE + msg_len);
Ok(Some(item))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn encode_should_fail_when_item_is_zero_bytes() {
let mut codec = PlainCodec::new();
let mut buf = BytesMut::new();
let result = codec.encode(&[], &mut buf);
match result {
Err(x) if x.kind() == io::ErrorKind::InvalidInput => {}
x => panic!("Unexpected result: {:?}", x),
}
}
#[test]
fn encode_should_build_a_frame_containing_a_length_and_item() {
let mut codec = PlainCodec::new();
let mut buf = BytesMut::new();
codec
.encode(b"hello, world", &mut buf)
.expect("Failed to encode");
let len = buf.get_u64() as usize;
assert_eq!(len, 12, "Wrong length encoded");
assert_eq!(buf.as_ref(), b"hello, world");
}
#[test]
fn decode_should_return_none_if_data_smaller_than_or_equal_to_item_length_field() {
let mut codec = PlainCodec::new();
let mut buf = BytesMut::new();
buf.put_bytes(0, LEN_SIZE);
let result = codec.decode(&mut buf);
assert!(
matches!(result, Ok(None)),
"Unexpected result: {:?}",
result
);
}
#[test]
fn decode_should_return_none_if_not_enough_data_for_frame() {
let mut codec = PlainCodec::new();
let mut buf = BytesMut::new();
buf.put_u64(0);
let result = codec.decode(&mut buf);
assert!(
matches!(result, Ok(None)),
"Unexpected result: {:?}",
result
);
}
#[test]
fn decode_should_fail_if_encoded_item_length_is_zero() {
let mut codec = PlainCodec::new();
let mut buf = BytesMut::new();
buf.put_u64(0);
buf.put_u8(255);
let result = codec.decode(&mut buf);
match result {
Err(x) if x.kind() == io::ErrorKind::InvalidData => {}
x => panic!("Unexpected result: {:?}", x),
}
}
#[test]
fn decode_should_advance_src_by_frame_size_even_if_item_length_is_zero() {
let mut codec = PlainCodec::new();
let mut buf = BytesMut::new();
buf.put_u64(0);
buf.put_bytes(0, 3);
assert!(
codec.decode(&mut buf).is_err(),
"Decode unexpectedly succeeded"
);
assert_eq!(buf.len(), 3, "Advanced an unexpected amount in src buf");
}
#[test]
fn decode_should_advance_src_by_frame_size_when_successful() {
let mut codec = PlainCodec::new();
// Add 3 extra bytes after a full frame
let mut buf = BytesMut::new();
codec
.encode(b"hello, world", &mut buf)
.expect("Failed to encode");
buf.put_bytes(0, 3);
assert!(codec.decode(&mut buf).is_ok(), "Decode unexpectedly failed");
assert_eq!(buf.len(), 3, "Advanced an unexpected amount in src buf");
}
#[test]
fn decode_should_return_some_byte_vec_when_successful() {
let mut codec = PlainCodec::new();
let mut buf = BytesMut::new();
codec
.encode(b"hello, world", &mut buf)
.expect("Failed to encode");
let item = codec
.decode(&mut buf)
.expect("Failed to decode")
.expect("Item not properly captured");
assert_eq!(item, b"hello, world");
}
}

@ -1,269 +0,0 @@
use crate::{Codec, SecretKey, SecretKey32};
use bytes::{Buf, BufMut, BytesMut};
use chacha20poly1305::{aead::Aead, Key, KeyInit, XChaCha20Poly1305, XNonce};
use std::{convert::TryInto, fmt};
use tokio::io;
/// Total bytes to use as the len field denoting a frame's size
const LEN_SIZE: usize = 8;
/// Total bytes to use for nonce
const NONCE_SIZE: usize = 24;
/// Represents the codec to encode & decode data while also encrypting/decrypting it
///
/// Uses a 32-byte key internally
#[derive(Clone)]
pub struct XChaCha20Poly1305Codec {
cipher: XChaCha20Poly1305,
}
impl_traits_for_codec!(XChaCha20Poly1305Codec);
impl XChaCha20Poly1305Codec {
pub fn new(key: &[u8]) -> Self {
let key = Key::from_slice(key);
let cipher = XChaCha20Poly1305::new(key);
Self { cipher }
}
}
impl From<SecretKey32> for XChaCha20Poly1305Codec {
/// Create a new XChaCha20Poly1305 codec with a 32-byte key
fn from(secret_key: SecretKey32) -> Self {
Self::new(secret_key.unprotected_as_bytes())
}
}
impl fmt::Debug for XChaCha20Poly1305Codec {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("XChaCha20Poly1305Codec")
.field("cipher", &"**OMITTED**".to_string())
.finish()
}
}
impl Codec for XChaCha20Poly1305Codec {
fn encode(&mut self, item: &[u8], dst: &mut BytesMut) -> io::Result<()> {
// Validate that we can fit the message plus nonce +
if item.is_empty() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Empty item provided",
));
}
// NOTE: As seen in orion, with a 24-bit nonce, it's safe to generate instead of
// maintaining a stateful counter due to its size (24-byte secret key generation
// will never panic)
let nonce_key = SecretKey::<NONCE_SIZE>::generate().unwrap();
let nonce = XNonce::from_slice(nonce_key.unprotected_as_bytes());
let ciphertext = self
.cipher
.encrypt(nonce, item)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Encryption failed"))?;
dst.reserve(8 + nonce.len() + ciphertext.len());
// Add data in form of {LEN}{NONCE}{CIPHER TEXT}
dst.put_u64((nonce_key.len() + ciphertext.len()) as u64);
dst.put_slice(nonce.as_slice());
dst.extend(ciphertext);
Ok(())
}
fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<Vec<u8>>> {
// First, check if we have more data than just our frame's message length
if src.len() <= LEN_SIZE {
return Ok(None);
}
// Second, retrieve total size of our frame's message
let msg_len = u64::from_be_bytes(src[..LEN_SIZE].try_into().unwrap()) as usize;
if msg_len <= NONCE_SIZE {
// Ensure we advance to remove the frame
src.advance(LEN_SIZE + msg_len);
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Frame's msg cannot have length less than 25",
));
}
// Third, check if we have all data for our frame; if not, exit early
if src.len() < msg_len + LEN_SIZE {
return Ok(None);
}
// Fourth, retrieve the nonce used with the ciphertext
let nonce = XNonce::from_slice(&src[LEN_SIZE..(NONCE_SIZE + LEN_SIZE)]);
// Fifth, acquire the encrypted & signed ciphertext
let ciphertext = &src[(NONCE_SIZE + LEN_SIZE)..(msg_len + LEN_SIZE)];
// Sixth, convert ciphertext back into our item
let item = self.cipher.decrypt(nonce, ciphertext);
// Seventh, advance so frame is no longer kept around
src.advance(LEN_SIZE + msg_len);
// Eighth, report an error if there is one
let item =
item.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Decryption failed"))?;
Ok(Some(item))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn encode_should_fail_when_item_is_zero_bytes() {
let key = SecretKey32::default();
let mut codec = XChaCha20Poly1305Codec::from(key);
let mut buf = BytesMut::new();
let result = codec.encode(&[], &mut buf);
match result {
Err(x) if x.kind() == io::ErrorKind::InvalidInput => {}
x => panic!("Unexpected result: {:?}", x),
}
}
#[test]
fn encode_should_build_a_frame_containing_a_length_nonce_and_ciphertext() {
let key = SecretKey32::default();
let mut codec = XChaCha20Poly1305Codec::from(key);
let mut buf = BytesMut::new();
codec
.encode(b"hello, world", &mut buf)
.expect("Failed to encode");
let len = buf.get_u64() as usize;
assert!(buf.len() > NONCE_SIZE, "Msg size not big enough");
assert_eq!(len, buf.len(), "Msg size does not match attached size");
}
#[test]
fn decode_should_return_none_if_data_smaller_than_or_equal_to_frame_length_field() {
let key = SecretKey32::default();
let mut codec = XChaCha20Poly1305Codec::from(key);
let mut buf = BytesMut::new();
buf.put_bytes(0, LEN_SIZE);
let result = codec.decode(&mut buf);
assert!(
matches!(result, Ok(None)),
"Unexpected result: {:?}",
result
);
}
#[test]
fn decode_should_return_none_if_not_enough_data_for_frame() {
let key = SecretKey32::default();
let mut codec = XChaCha20Poly1305Codec::from(key);
let mut buf = BytesMut::new();
buf.put_u64(0);
let result = codec.decode(&mut buf);
assert!(
matches!(result, Ok(None)),
"Unexpected result: {:?}",
result
);
}
#[test]
fn decode_should_fail_if_encoded_frame_length_is_smaller_than_nonce_plus_data() {
let key = SecretKey32::default();
let mut codec = XChaCha20Poly1305Codec::from(key);
// NONCE_SIZE + 1 is minimum for frame length
let mut buf = BytesMut::new();
buf.put_u64(NONCE_SIZE as u64);
buf.put_bytes(0, NONCE_SIZE);
let result = codec.decode(&mut buf);
match result {
Err(x) if x.kind() == io::ErrorKind::InvalidData => {}
x => panic!("Unexpected result: {:?}", x),
}
}
#[test]
fn decode_should_advance_src_by_frame_size_even_if_frame_length_is_too_small() {
let key = SecretKey32::default();
let mut codec = XChaCha20Poly1305Codec::from(key);
// LEN_SIZE + NONCE_SIZE + msg not matching encryption + 3 more bytes
let mut buf = BytesMut::new();
buf.put_u64(NONCE_SIZE as u64);
buf.put_bytes(0, NONCE_SIZE);
buf.put_bytes(0, 3);
assert!(
codec.decode(&mut buf).is_err(),
"Decode unexpectedly succeeded"
);
assert_eq!(buf.len(), 3, "Advanced an unexpected amount in src buf");
}
#[test]
fn decode_should_advance_src_by_frame_size_even_if_decryption_fails() {
let key = SecretKey32::default();
let mut codec = XChaCha20Poly1305Codec::from(key);
// LEN_SIZE + NONCE_SIZE + msg not matching encryption + 3 more bytes
let mut buf = BytesMut::new();
buf.put_u64((NONCE_SIZE + 12) as u64);
buf.put_bytes(0, NONCE_SIZE);
buf.put_slice(b"hello, world");
buf.put_bytes(0, 3);
assert!(
codec.decode(&mut buf).is_err(),
"Decode unexpectedly succeeded"
);
assert_eq!(buf.len(), 3, "Advanced an unexpected amount in src buf");
}
#[test]
fn decode_should_advance_src_by_frame_size_when_successful() {
let key = SecretKey32::default();
let mut codec = XChaCha20Poly1305Codec::from(key);
// Add 3 extra bytes after a full frame
let mut buf = BytesMut::new();
codec
.encode(b"hello, world", &mut buf)
.expect("Failed to encode");
buf.put_bytes(0, 3);
assert!(codec.decode(&mut buf).is_ok(), "Decode unexpectedly failed");
assert_eq!(buf.len(), 3, "Advanced an unexpected amount in src buf");
}
#[test]
fn decode_should_return_some_byte_vec_when_successful() {
let key = SecretKey32::default();
let mut codec = XChaCha20Poly1305Codec::from(key);
let mut buf = BytesMut::new();
codec
.encode(b"hello, world", &mut buf)
.expect("Failed to encode");
let item = codec
.decode(&mut buf)
.expect("Failed to decode")
.expect("Item not properly captured");
assert_eq!(item, b"hello, world");
}
}

@ -0,0 +1,20 @@
mod any;
pub mod authentication;
mod connection;
mod destination;
mod listener;
mod map;
mod packet;
mod port;
mod transport;
pub(crate) mod utils;
pub use any::*;
pub(crate) use connection::Connection;
pub use connection::ConnectionId;
pub use destination::*;
pub use listener::*;
pub use map::*;
pub use packet::*;
pub use port::*;
pub use transport::*;

@ -0,0 +1,10 @@
mod authenticator;
mod handler;
mod keychain;
mod methods;
pub mod msg;
pub use authenticator::*;
pub use handler::*;
pub use keychain::*;
pub use methods::*;

@ -0,0 +1,672 @@
use super::{msg::*, AuthHandler};
use crate::common::{utils, FramedTransport, Transport};
use async_trait::async_trait;
use log::*;
use std::io;
/// Represents an interface for authenticating with a server.
#[async_trait]
pub trait Authenticate {
/// Performs authentication by leveraging the `handler` for any received challenge.
async fn authenticate(&mut self, mut handler: impl AuthHandler + Send) -> io::Result<()>;
}
/// Represents an interface for submitting challenges for authentication.
#[async_trait]
pub trait Authenticator: Send {
/// Issues an initialization notice and returns the response indicating which authentication
/// methods to pursue
async fn initialize(
&mut self,
initialization: Initialization,
) -> io::Result<InitializationResponse>;
/// Issues a challenge and returns the answers to the `questions` asked.
async fn challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse>;
/// Requests verification of some `kind` and `text`, returning true if passed verification.
async fn verify(&mut self, verification: Verification) -> io::Result<VerificationResponse>;
/// Reports information with no response expected.
async fn info(&mut self, info: Info) -> io::Result<()>;
/// Reports an error occurred during authentication, consuming the authenticator since no more
/// challenges should be issued.
async fn error(&mut self, error: Error) -> io::Result<()>;
/// Reports that the authentication has started for a specific method.
async fn start_method(&mut self, start_method: StartMethod) -> io::Result<()>;
/// Reports that the authentication has finished successfully, consuming the authenticator
/// since no more challenges should be issued.
async fn finished(&mut self) -> io::Result<()>;
}
macro_rules! write_frame {
($transport:expr, $data:expr) => {{
let data = utils::serialize_to_vec(&$data)?;
if log_enabled!(Level::Trace) {
trace!("Writing data as frame: {data:?}");
}
$transport.write_frame(data).await?
}};
}
macro_rules! next_frame_as {
($transport:expr, $type:ident, $variant:ident) => {{
match { next_frame_as!($transport, $type) } {
$type::$variant(x) => x,
x => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("Unexpected frame: {x:?}"),
))
}
}
}};
($transport:expr, $type:ident) => {{
let frame = $transport.read_frame().await?.ok_or_else(|| {
io::Error::new(
io::ErrorKind::UnexpectedEof,
concat!(
"Transport closed early waiting for frame of type ",
stringify!($type),
),
)
})?;
match utils::deserialize_from_slice::<$type>(frame.as_item()) {
Ok(frame) => frame,
Err(x) => {
if log_enabled!(Level::Trace) {
trace!(
"Failed to deserialize frame item as {}: {:?}",
stringify!($type),
frame.as_item()
);
}
Err(x)?;
unreachable!();
}
}
}};
}
#[async_trait]
impl<T> Authenticate for FramedTransport<T>
where
T: Transport,
{
async fn authenticate(&mut self, mut handler: impl AuthHandler + Send) -> io::Result<()> {
loop {
trace!("Authenticate::authenticate waiting on next authentication frame");
match next_frame_as!(self, Authentication) {
Authentication::Initialization(x) => {
trace!("Authenticate::Initialization({x:?})");
let response = handler.on_initialization(x).await?;
write_frame!(self, AuthenticationResponse::Initialization(response));
}
Authentication::Challenge(x) => {
trace!("Authenticate::Challenge({x:?})");
let response = handler.on_challenge(x).await?;
write_frame!(self, AuthenticationResponse::Challenge(response));
}
Authentication::Verification(x) => {
trace!("Authenticate::Verify({x:?})");
let response = handler.on_verification(x).await?;
write_frame!(self, AuthenticationResponse::Verification(response));
}
Authentication::Info(x) => {
trace!("Authenticate::Info({x:?})");
handler.on_info(x).await?;
}
Authentication::Error(x) => {
trace!("Authenticate::Error({x:?})");
handler.on_error(x.clone()).await?;
if x.is_fatal() {
return Err(x.into_io_permission_denied());
}
}
Authentication::StartMethod(x) => {
trace!("Authenticate::StartMethod({x:?})");
handler.on_start_method(x).await?;
}
Authentication::Finished => {
trace!("Authenticate::Finished");
handler.on_finished().await?;
return Ok(());
}
}
}
}
}
#[async_trait]
impl<T> Authenticator for FramedTransport<T>
where
T: Transport,
{
async fn initialize(
&mut self,
initialization: Initialization,
) -> io::Result<InitializationResponse> {
trace!("Authenticator::initialize({initialization:?})");
write_frame!(self, Authentication::Initialization(initialization));
let response = next_frame_as!(self, AuthenticationResponse, Initialization);
Ok(response)
}
async fn challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
trace!("Authenticator::challenge({challenge:?})");
write_frame!(self, Authentication::Challenge(challenge));
let response = next_frame_as!(self, AuthenticationResponse, Challenge);
Ok(response)
}
async fn verify(&mut self, verification: Verification) -> io::Result<VerificationResponse> {
trace!("Authenticator::verify({verification:?})");
write_frame!(self, Authentication::Verification(verification));
let response = next_frame_as!(self, AuthenticationResponse, Verification);
Ok(response)
}
async fn info(&mut self, info: Info) -> io::Result<()> {
trace!("Authenticator::info({info:?})");
write_frame!(self, Authentication::Info(info));
Ok(())
}
async fn error(&mut self, error: Error) -> io::Result<()> {
trace!("Authenticator::error({error:?})");
write_frame!(self, Authentication::Error(error));
Ok(())
}
async fn start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
trace!("Authenticator::start_method({start_method:?})");
write_frame!(self, Authentication::StartMethod(start_method));
Ok(())
}
async fn finished(&mut self) -> io::Result<()> {
trace!("Authenticator::finished()");
write_frame!(self, Authentication::Finished);
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::common::authentication::AuthMethodHandler;
use test_log::test;
use tokio::sync::mpsc;
#[async_trait]
trait TestAuthHandler {
async fn on_initialization(
&mut self,
_: Initialization,
) -> io::Result<InitializationResponse> {
Err(io::Error::from(io::ErrorKind::Unsupported))
}
async fn on_start_method(&mut self, _: StartMethod) -> io::Result<()> {
Err(io::Error::from(io::ErrorKind::Unsupported))
}
async fn on_finished(&mut self) -> io::Result<()> {
Err(io::Error::from(io::ErrorKind::Unsupported))
}
async fn on_challenge(&mut self, _: Challenge) -> io::Result<ChallengeResponse> {
Err(io::Error::from(io::ErrorKind::Unsupported))
}
async fn on_verification(&mut self, _: Verification) -> io::Result<VerificationResponse> {
Err(io::Error::from(io::ErrorKind::Unsupported))
}
async fn on_info(&mut self, _: Info) -> io::Result<()> {
Err(io::Error::from(io::ErrorKind::Unsupported))
}
async fn on_error(&mut self, _: Error) -> io::Result<()> {
Err(io::Error::from(io::ErrorKind::Unsupported))
}
}
#[async_trait]
impl<T: TestAuthHandler + Send> AuthHandler for T {
async fn on_initialization(
&mut self,
x: Initialization,
) -> io::Result<InitializationResponse> {
TestAuthHandler::on_initialization(self, x).await
}
async fn on_start_method(&mut self, x: StartMethod) -> io::Result<()> {
TestAuthHandler::on_start_method(self, x).await
}
async fn on_finished(&mut self) -> io::Result<()> {
TestAuthHandler::on_finished(self).await
}
}
#[async_trait]
impl<T: TestAuthHandler + Send> AuthMethodHandler for T {
async fn on_challenge(&mut self, x: Challenge) -> io::Result<ChallengeResponse> {
TestAuthHandler::on_challenge(self, x).await
}
async fn on_verification(&mut self, x: Verification) -> io::Result<VerificationResponse> {
TestAuthHandler::on_verification(self, x).await
}
async fn on_info(&mut self, x: Info) -> io::Result<()> {
TestAuthHandler::on_info(self, x).await
}
async fn on_error(&mut self, x: Error) -> io::Result<()> {
TestAuthHandler::on_error(self, x).await
}
}
macro_rules! auth_handler {
(@no_challenge @no_verification @tx($tx:ident, $ty:ty) $($methods:item)*) => {
auth_handler! {
@tx($tx, $ty)
async fn on_challenge(&mut self, _: Challenge) -> io::Result<ChallengeResponse> {
Err(io::Error::from(io::ErrorKind::Unsupported))
}
async fn on_verification(
&mut self,
_: Verification,
) -> io::Result<VerificationResponse> {
Err(io::Error::from(io::ErrorKind::Unsupported))
}
$($methods)*
}
};
(@no_challenge @tx($tx:ident, $ty:ty) $($methods:item)*) => {
auth_handler! {
@tx($tx, $ty)
async fn on_challenge(&mut self, _: Challenge) -> io::Result<ChallengeResponse> {
Err(io::Error::from(io::ErrorKind::Unsupported))
}
$($methods)*
}
};
(@no_verification @tx($tx:ident, $ty:ty) $($methods:item)*) => {
auth_handler! {
@tx($tx, $ty)
async fn on_verification(
&mut self,
_: Verification,
) -> io::Result<VerificationResponse> {
Err(io::Error::from(io::ErrorKind::Unsupported))
}
$($methods)*
}
};
(@tx($tx:ident, $ty:ty) $($methods:item)*) => {{
#[allow(dead_code)]
struct __InlineAuthHandler {
tx: mpsc::Sender<$ty>,
}
#[async_trait]
impl TestAuthHandler for __InlineAuthHandler {
$($methods)*
}
__InlineAuthHandler { tx: $tx }
}};
}
#[test(tokio::test)]
async fn authenticator_initialization_should_be_able_to_successfully_complete_round_trip() {
let (mut t1, mut t2) = FramedTransport::test_pair(100);
let (tx, _) = mpsc::channel(1);
let task = tokio::spawn(async move {
t2.authenticate(auth_handler! {
@no_challenge
@no_verification
@tx(tx, ())
async fn on_initialization(
&mut self,
initialization: Initialization,
) -> io::Result<InitializationResponse> {
Ok(InitializationResponse {
methods: initialization.methods,
})
}
})
.await
.unwrap()
});
let response = t1
.initialize(Initialization {
methods: vec!["test method".to_string()].into_iter().collect(),
})
.await
.unwrap();
assert!(
!task.is_finished(),
"Auth handler unexpectedly finished without signal"
);
assert_eq!(
response,
InitializationResponse {
methods: vec!["test method".to_string()].into_iter().collect()
}
);
}
#[test(tokio::test)]
async fn authenticator_challenge_should_be_able_to_successfully_complete_round_trip() {
let (mut t1, mut t2) = FramedTransport::test_pair(100);
let (tx, _) = mpsc::channel(1);
let task = tokio::spawn(async move {
t2.authenticate(auth_handler! {
@no_verification
@tx(tx, ())
async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
assert_eq!(challenge.questions, vec![Question {
label: "label".to_string(),
text: "text".to_string(),
options: vec![("question_key".to_string(), "question_value".to_string())]
.into_iter()
.collect(),
}]);
assert_eq!(
challenge.options,
vec![("key".to_string(), "value".to_string())].into_iter().collect(),
);
Ok(ChallengeResponse {
answers: vec!["some answer".to_string()].into_iter().collect(),
})
}
})
.await
.unwrap()
});
let response = t1
.challenge(Challenge {
questions: vec![Question {
label: "label".to_string(),
text: "text".to_string(),
options: vec![("question_key".to_string(), "question_value".to_string())]
.into_iter()
.collect(),
}],
options: vec![("key".to_string(), "value".to_string())]
.into_iter()
.collect(),
})
.await
.unwrap();
assert!(
!task.is_finished(),
"Auth handler unexpectedly finished without signal"
);
assert_eq!(
response,
ChallengeResponse {
answers: vec!["some answer".to_string()],
}
);
}
#[test(tokio::test)]
async fn authenticator_verification_should_be_able_to_successfully_complete_round_trip() {
let (mut t1, mut t2) = FramedTransport::test_pair(100);
let (tx, _) = mpsc::channel(1);
let task = tokio::spawn(async move {
t2.authenticate(auth_handler! {
@no_challenge
@tx(tx, ())
async fn on_verification(
&mut self,
verification: Verification,
) -> io::Result<VerificationResponse> {
assert_eq!(verification.kind, VerificationKind::Host);
assert_eq!(verification.text, "some text");
Ok(VerificationResponse {
valid: true,
})
}
})
.await
.unwrap()
});
let response = t1
.verify(Verification {
kind: VerificationKind::Host,
text: "some text".to_string(),
})
.await
.unwrap();
assert!(
!task.is_finished(),
"Auth handler unexpectedly finished without signal"
);
assert_eq!(response, VerificationResponse { valid: true });
}
#[test(tokio::test)]
async fn authenticator_info_should_be_able_to_be_sent_to_auth_handler() {
let (mut t1, mut t2) = FramedTransport::test_pair(100);
let (tx, mut rx) = mpsc::channel(1);
let task = tokio::spawn(async move {
t2.authenticate(auth_handler! {
@no_challenge
@no_verification
@tx(tx, Info)
async fn on_info(
&mut self,
info: Info,
) -> io::Result<()> {
self.tx.send(info).await.unwrap();
Ok(())
}
})
.await
.unwrap()
});
t1.info(Info {
text: "some text".to_string(),
})
.await
.unwrap();
assert_eq!(
rx.recv().await.unwrap(),
Info {
text: "some text".to_string()
}
);
assert!(
!task.is_finished(),
"Auth handler unexpectedly finished without signal"
);
}
#[test(tokio::test)]
async fn authenticator_error_should_be_able_to_be_sent_to_auth_handler() {
let (mut t1, mut t2) = FramedTransport::test_pair(100);
let (tx, mut rx) = mpsc::channel(1);
let task = tokio::spawn(async move {
t2.authenticate(auth_handler! {
@no_challenge
@no_verification
@tx(tx, Error)
async fn on_error(&mut self, error: Error) -> io::Result<()> {
self.tx.send(error).await.unwrap();
Ok(())
}
})
.await
.unwrap()
});
t1.error(Error {
kind: ErrorKind::Error,
text: "some text".to_string(),
})
.await
.unwrap();
assert_eq!(
rx.recv().await.unwrap(),
Error {
kind: ErrorKind::Error,
text: "some text".to_string(),
}
);
assert!(
!task.is_finished(),
"Auth handler unexpectedly finished without signal"
);
}
#[test(tokio::test)]
async fn auth_handler_received_error_should_fail_auth_handler_if_fatal() {
let (mut t1, mut t2) = FramedTransport::test_pair(100);
let (tx, mut rx) = mpsc::channel(1);
let task = tokio::spawn(async move {
t2.authenticate(auth_handler! {
@no_challenge
@no_verification
@tx(tx, Error)
async fn on_error(&mut self, error: Error) -> io::Result<()> {
self.tx.send(error).await.unwrap();
Ok(())
}
})
.await
.unwrap()
});
t1.error(Error {
kind: ErrorKind::Fatal,
text: "some text".to_string(),
})
.await
.unwrap();
assert_eq!(
rx.recv().await.unwrap(),
Error {
kind: ErrorKind::Fatal,
text: "some text".to_string(),
}
);
// Verify that the handler exited with an error
task.await.unwrap_err();
}
#[test(tokio::test)]
async fn authenticator_start_method_should_be_able_to_be_sent_to_auth_handler() {
let (mut t1, mut t2) = FramedTransport::test_pair(100);
let (tx, mut rx) = mpsc::channel(1);
let task = tokio::spawn(async move {
t2.authenticate(auth_handler! {
@no_challenge
@no_verification
@tx(tx, StartMethod)
async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
self.tx.send(start_method).await.unwrap();
Ok(())
}
})
.await
.unwrap()
});
t1.start_method(StartMethod {
method: "some method".to_string(),
})
.await
.unwrap();
assert_eq!(
rx.recv().await.unwrap(),
StartMethod {
method: "some method".to_string()
}
);
assert!(
!task.is_finished(),
"Auth handler unexpectedly finished without signal"
);
}
#[test(tokio::test)]
async fn authenticator_finished_should_be_able_to_be_sent_to_auth_handler() {
let (mut t1, mut t2) = FramedTransport::test_pair(100);
let (tx, mut rx) = mpsc::channel(1);
let task = tokio::spawn(async move {
t2.authenticate(auth_handler! {
@no_challenge
@no_verification
@tx(tx, ())
async fn on_finished(&mut self) -> io::Result<()> {
self.tx.send(()).await.unwrap();
Ok(())
}
})
.await
.unwrap()
});
t1.finished().await.unwrap();
// Verify that the callback was triggered
rx.recv().await.unwrap();
// Finished should signal that the handler completed successfully
task.await.unwrap();
}
}

@ -0,0 +1,343 @@
use super::msg::*;
use crate::common::authentication::Authenticator;
use crate::common::HeapSecretKey;
use async_trait::async_trait;
use std::collections::HashMap;
use std::io;
mod methods;
pub use methods::*;
/// Interface for a handler of authentication requests for all methods.
#[async_trait]
pub trait AuthHandler: AuthMethodHandler + Send {
/// Callback when authentication is beginning, providing available authentication methods and
/// returning selected authentication methods to pursue.
async fn on_initialization(
&mut self,
initialization: Initialization,
) -> io::Result<InitializationResponse> {
Ok(InitializationResponse {
methods: initialization.methods,
})
}
/// Callback when authentication starts for a specific method.
#[allow(unused_variables)]
async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
Ok(())
}
/// Callback when authentication is finished and no more requests will be received.
async fn on_finished(&mut self) -> io::Result<()> {
Ok(())
}
}
/// Dummy implementation of [`AuthHandler`] where any challenge or verification request will
/// instantly fail.
pub struct DummyAuthHandler;
#[async_trait]
impl AuthHandler for DummyAuthHandler {}
#[async_trait]
impl AuthMethodHandler for DummyAuthHandler {
async fn on_challenge(&mut self, _: Challenge) -> io::Result<ChallengeResponse> {
Err(io::Error::from(io::ErrorKind::Unsupported))
}
async fn on_verification(&mut self, _: Verification) -> io::Result<VerificationResponse> {
Err(io::Error::from(io::ErrorKind::Unsupported))
}
async fn on_info(&mut self, _: Info) -> io::Result<()> {
Err(io::Error::from(io::ErrorKind::Unsupported))
}
async fn on_error(&mut self, _: Error) -> io::Result<()> {
Err(io::Error::from(io::ErrorKind::Unsupported))
}
}
/// Implementation of [`AuthHandler`] that uses the same [`AuthMethodHandler`] for all methods.
pub struct SingleAuthHandler(Box<dyn AuthMethodHandler>);
impl SingleAuthHandler {
pub fn new<T: AuthMethodHandler + 'static>(method_handler: T) -> Self {
Self(Box::new(method_handler))
}
}
#[async_trait]
impl AuthHandler for SingleAuthHandler {}
#[async_trait]
impl AuthMethodHandler for SingleAuthHandler {
async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
self.0.on_challenge(challenge).await
}
async fn on_verification(
&mut self,
verification: Verification,
) -> io::Result<VerificationResponse> {
self.0.on_verification(verification).await
}
async fn on_info(&mut self, info: Info) -> io::Result<()> {
self.0.on_info(info).await
}
async fn on_error(&mut self, error: Error) -> io::Result<()> {
self.0.on_error(error).await
}
}
/// Implementation of [`AuthHandler`] that maintains a map of [`AuthMethodHandler`] implementations
/// for specific methods, invoking [`on_challenge`], [`on_verification`], [`on_info`], and
/// [`on_error`] for a specific handler based on an associated id.
///
/// [`on_challenge`]: AuthMethodHandler::on_challenge
/// [`on_verification`]: AuthMethodHandler::on_verification
/// [`on_info`]: AuthMethodHandler::on_info
/// [`on_error`]: AuthMethodHandler::on_error
pub struct AuthHandlerMap {
active: String,
map: HashMap<&'static str, Box<dyn AuthMethodHandler>>,
}
impl AuthHandlerMap {
/// Creates a new, empty map of auth method handlers.
pub fn new() -> Self {
Self {
active: String::new(),
map: HashMap::new(),
}
}
/// Returns the `id` of the active [`AuthMethodHandler`].
pub fn active_id(&self) -> &str {
&self.active
}
/// Sets the active [`AuthMethodHandler`] by its `id`.
pub fn set_active_id(&mut self, id: impl Into<String>) {
self.active = id.into();
}
/// Inserts the specified `handler` into the map, associating it with `id` for determining the
/// method that would trigger this handler.
pub fn insert_method_handler<T: AuthMethodHandler + 'static>(
&mut self,
id: &'static str,
handler: T,
) -> Option<Box<dyn AuthMethodHandler>> {
self.map.insert(id, Box::new(handler))
}
/// Removes a handler with the associated `id`.
pub fn remove_method_handler(
&mut self,
id: &'static str,
) -> Option<Box<dyn AuthMethodHandler>> {
self.map.remove(id)
}
/// Retrieves a mutable reference to the active [`AuthMethodHandler`] with the specified `id`,
/// returning an error if no handler for the active id is found.
pub fn get_mut_active_method_handler_or_error(
&mut self,
) -> io::Result<&mut (dyn AuthMethodHandler + 'static)> {
let id = self.active.clone();
self.get_mut_active_method_handler().ok_or_else(|| {
io::Error::new(io::ErrorKind::Other, format!("No active handler for {id}"))
})
}
/// Retrieves a mutable reference to the active [`AuthMethodHandler`] with the specified `id`.
pub fn get_mut_active_method_handler(
&mut self,
) -> Option<&mut (dyn AuthMethodHandler + 'static)> {
// TODO: Optimize this
self.get_mut_method_handler(&self.active.clone())
}
/// Retrieves a mutable reference to the [`AuthMethodHandler`] with the specified `id`.
pub fn get_mut_method_handler(
&mut self,
id: &str,
) -> Option<&mut (dyn AuthMethodHandler + 'static)> {
self.map.get_mut(id).map(|h| h.as_mut())
}
}
impl AuthHandlerMap {
/// Consumes the map, returning a new map that supports the `static_key` method.
pub fn with_static_key(mut self, key: impl Into<HeapSecretKey>) -> Self {
self.insert_method_handler("static_key", StaticKeyAuthMethodHandler::simple(key));
self
}
}
impl Default for AuthHandlerMap {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl AuthHandler for AuthHandlerMap {
async fn on_initialization(
&mut self,
initialization: Initialization,
) -> io::Result<InitializationResponse> {
let methods = initialization
.methods
.into_iter()
.filter(|method| self.map.contains_key(method.as_str()))
.collect();
Ok(InitializationResponse { methods })
}
async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
self.set_active_id(start_method.method);
Ok(())
}
async fn on_finished(&mut self) -> io::Result<()> {
Ok(())
}
}
#[async_trait]
impl AuthMethodHandler for AuthHandlerMap {
async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
let handler = self.get_mut_active_method_handler_or_error()?;
handler.on_challenge(challenge).await
}
async fn on_verification(
&mut self,
verification: Verification,
) -> io::Result<VerificationResponse> {
let handler = self.get_mut_active_method_handler_or_error()?;
handler.on_verification(verification).await
}
async fn on_info(&mut self, info: Info) -> io::Result<()> {
let handler = self.get_mut_active_method_handler_or_error()?;
handler.on_info(info).await
}
async fn on_error(&mut self, error: Error) -> io::Result<()> {
let handler = self.get_mut_active_method_handler_or_error()?;
handler.on_error(error).await
}
}
/// Implementation of [`AuthHandler`] that redirects all requests to an [`Authenticator`].
pub struct ProxyAuthHandler<'a>(&'a mut dyn Authenticator);
impl<'a> ProxyAuthHandler<'a> {
pub fn new(authenticator: &'a mut dyn Authenticator) -> Self {
Self(authenticator)
}
}
#[async_trait]
impl<'a> AuthHandler for ProxyAuthHandler<'a> {
async fn on_initialization(
&mut self,
initialization: Initialization,
) -> io::Result<InitializationResponse> {
Authenticator::initialize(self.0, initialization).await
}
async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
Authenticator::start_method(self.0, start_method).await
}
async fn on_finished(&mut self) -> io::Result<()> {
Authenticator::finished(self.0).await
}
}
#[async_trait]
impl<'a> AuthMethodHandler for ProxyAuthHandler<'a> {
async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
Authenticator::challenge(self.0, challenge).await
}
async fn on_verification(
&mut self,
verification: Verification,
) -> io::Result<VerificationResponse> {
Authenticator::verify(self.0, verification).await
}
async fn on_info(&mut self, info: Info) -> io::Result<()> {
Authenticator::info(self.0, info).await
}
async fn on_error(&mut self, error: Error) -> io::Result<()> {
Authenticator::error(self.0, error).await
}
}
/// Implementation of [`AuthHandler`] that holds a mutable reference to another [`AuthHandler`]
/// trait object to use underneath.
pub struct DynAuthHandler<'a>(&'a mut dyn AuthHandler);
impl<'a> DynAuthHandler<'a> {
pub fn new(handler: &'a mut dyn AuthHandler) -> Self {
Self(handler)
}
}
impl<'a, T: AuthHandler> From<&'a mut T> for DynAuthHandler<'a> {
fn from(handler: &'a mut T) -> Self {
Self::new(handler as &mut dyn AuthHandler)
}
}
#[async_trait]
impl<'a> AuthHandler for DynAuthHandler<'a> {
async fn on_initialization(
&mut self,
initialization: Initialization,
) -> io::Result<InitializationResponse> {
self.0.on_initialization(initialization).await
}
async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
self.0.on_start_method(start_method).await
}
async fn on_finished(&mut self) -> io::Result<()> {
self.0.on_finished().await
}
}
#[async_trait]
impl<'a> AuthMethodHandler for DynAuthHandler<'a> {
async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
self.0.on_challenge(challenge).await
}
async fn on_verification(
&mut self,
verification: Verification,
) -> io::Result<VerificationResponse> {
self.0.on_verification(verification).await
}
async fn on_info(&mut self, info: Info) -> io::Result<()> {
self.0.on_info(info).await
}
async fn on_error(&mut self, error: Error) -> io::Result<()> {
self.0.on_error(error).await
}
}

@ -0,0 +1,33 @@
use super::{
Challenge, ChallengeResponse, Error, Info, Verification, VerificationKind, VerificationResponse,
};
use async_trait::async_trait;
use std::io;
/// Interface for a handler of authentication requests for a specific authentication method.
#[async_trait]
pub trait AuthMethodHandler: Send {
/// Callback when a challenge is received, returning answers to the given questions.
async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse>;
/// Callback when a verification request is received, returning true if approvided or false if
/// unapproved.
async fn on_verification(
&mut self,
verification: Verification,
) -> io::Result<VerificationResponse>;
/// Callback when information is received. To fail, return an error from this function.
async fn on_info(&mut self, info: Info) -> io::Result<()>;
/// Callback when an error is received. Regardless of the result returned, this will terminate
/// the authenticator. In the situation where a custom error would be preferred, have this
/// callback return an error.
async fn on_error(&mut self, error: Error) -> io::Result<()>;
}
mod prompt;
pub use prompt::*;
mod static_key;
pub use static_key::*;

@ -0,0 +1,88 @@
use super::{
AuthMethodHandler, Challenge, ChallengeResponse, Error, Info, Verification, VerificationKind,
VerificationResponse,
};
use async_trait::async_trait;
use log::*;
use std::io;
/// Blocking implementation of [`AuthMethodHandler`] that uses prompts to communicate challenge &
/// verification requests, receiving responses to relay back.
pub struct PromptAuthMethodHandler<T, U> {
text_prompt: T,
password_prompt: U,
}
impl<T, U> PromptAuthMethodHandler<T, U> {
pub fn new(text_prompt: T, password_prompt: U) -> Self {
Self {
text_prompt,
password_prompt,
}
}
}
#[async_trait]
impl<T, U> AuthMethodHandler for PromptAuthMethodHandler<T, U>
where
T: Fn(&str) -> io::Result<String> + Send + Sync + 'static,
U: Fn(&str) -> io::Result<String> + Send + Sync + 'static,
{
async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
trace!("on_challenge({challenge:?})");
let mut answers = Vec::new();
for question in challenge.questions.iter() {
// Contains all prompt lines including same line
let mut lines = question.text.split('\n').collect::<Vec<_>>();
// Line that is prompt on same line as answer
let line = lines.pop().unwrap();
// Go ahead and display all other lines
for line in lines.into_iter() {
eprintln!("{}", line);
}
// Get an answer from user input, or use a blank string as an answer
// if we fail to get input from the user
let answer = (self.password_prompt)(line).unwrap_or_default();
answers.push(answer);
}
Ok(ChallengeResponse { answers })
}
async fn on_verification(
&mut self,
verification: Verification,
) -> io::Result<VerificationResponse> {
trace!("on_verify({verification:?})");
match verification.kind {
VerificationKind::Host => {
eprintln!("{}", verification.text);
let answer = (self.text_prompt)("Enter [y/N]> ")?;
trace!("Verify? Answer = '{answer}'");
Ok(VerificationResponse {
valid: matches!(answer.trim(), "y" | "Y" | "yes" | "YES"),
})
}
x => {
error!("Unsupported verify kind: {x}");
Ok(VerificationResponse { valid: false })
}
}
}
async fn on_info(&mut self, info: Info) -> io::Result<()> {
trace!("on_info({info:?})");
println!("{}", info.text);
Ok(())
}
async fn on_error(&mut self, error: Error) -> io::Result<()> {
trace!("on_error({error:?})");
eprintln!("{}: {}", error.kind, error.text);
Ok(())
}
}

@ -0,0 +1,171 @@
use super::{
AuthMethodHandler, Challenge, ChallengeResponse, Error, Info, Verification,
VerificationResponse,
};
use crate::common::HeapSecretKey;
use async_trait::async_trait;
use log::*;
use std::io;
/// Implementation of [`AuthMethodHandler`] that answers challenge requests using a static
/// [`HeapSecretKey`]. All other portions of method authentication are handled by another
/// [`AuthMethodHandler`].
pub struct StaticKeyAuthMethodHandler {
key: HeapSecretKey,
handler: Box<dyn AuthMethodHandler>,
}
impl StaticKeyAuthMethodHandler {
/// Creates a new [`StaticKeyAuthMethodHandler`] that responds to challenges using a static
/// `key`. All other requests are passed to the `handler`.
pub fn new<T: AuthMethodHandler + 'static>(key: impl Into<HeapSecretKey>, handler: T) -> Self {
Self {
key: key.into(),
handler: Box::new(handler),
}
}
/// Creates a new [`StaticKeyAuthMethodHandler`] that responds to challenges using a static
/// `key`. All other requests are passed automatically, meaning that verification is always
/// approvide and info/errors are ignored.
pub fn simple(key: impl Into<HeapSecretKey>) -> Self {
Self::new(key, {
struct __AuthMethodHandler;
#[async_trait]
impl AuthMethodHandler for __AuthMethodHandler {
async fn on_challenge(&mut self, _: Challenge) -> io::Result<ChallengeResponse> {
unreachable!("on_challenge should be handled by StaticKeyAuthMethodHandler");
}
async fn on_verification(
&mut self,
_: Verification,
) -> io::Result<VerificationResponse> {
Ok(VerificationResponse { valid: true })
}
async fn on_info(&mut self, _: Info) -> io::Result<()> {
Ok(())
}
async fn on_error(&mut self, _: Error) -> io::Result<()> {
Ok(())
}
}
__AuthMethodHandler
})
}
}
#[async_trait]
impl AuthMethodHandler for StaticKeyAuthMethodHandler {
async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
trace!("on_challenge({challenge:?})");
let mut answers = Vec::new();
for question in challenge.questions.iter() {
// Only challenges with a "key" label are allowed, all else will fail
if question.label != "key" {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Only 'key' challenges are supported",
));
}
answers.push(self.key.to_string());
}
Ok(ChallengeResponse { answers })
}
async fn on_verification(
&mut self,
verification: Verification,
) -> io::Result<VerificationResponse> {
trace!("on_verify({verification:?})");
self.handler.on_verification(verification).await
}
async fn on_info(&mut self, info: Info) -> io::Result<()> {
trace!("on_info({info:?})");
self.handler.on_info(info).await
}
async fn on_error(&mut self, error: Error) -> io::Result<()> {
trace!("on_error({error:?})");
self.handler.on_error(error).await
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::common::authentication::msg::{ErrorKind, Question, VerificationKind};
use test_log::test;
#[test(tokio::test)]
async fn on_challenge_should_fail_if_non_key_question_received() {
let mut handler = StaticKeyAuthMethodHandler::simple(HeapSecretKey::generate(32).unwrap());
handler
.on_challenge(Challenge {
questions: vec![Question::new("test")],
options: Default::default(),
})
.await
.unwrap_err();
}
#[test(tokio::test)]
async fn on_challenge_should_answer_with_stringified_key_for_key_questions() {
let mut handler = StaticKeyAuthMethodHandler::simple(HeapSecretKey::generate(32).unwrap());
let response = handler
.on_challenge(Challenge {
questions: vec![Question::new("key")],
options: Default::default(),
})
.await
.unwrap();
assert_eq!(response.answers.len(), 1, "Wrong answer set received");
assert!(!response.answers[0].is_empty(), "Empty answer being sent");
}
#[test(tokio::test)]
async fn on_verification_should_leverage_fallback_handler() {
let mut handler = StaticKeyAuthMethodHandler::simple(HeapSecretKey::generate(32).unwrap());
let response = handler
.on_verification(Verification {
kind: VerificationKind::Host,
text: "host".to_string(),
})
.await
.unwrap();
assert!(response.valid, "Unexpected result from fallback handler");
}
#[test(tokio::test)]
async fn on_info_should_leverage_fallback_handler() {
let mut handler = StaticKeyAuthMethodHandler::simple(HeapSecretKey::generate(32).unwrap());
handler
.on_info(Info {
text: "info".to_string(),
})
.await
.unwrap();
}
#[test(tokio::test)]
async fn on_error_should_leverage_fallback_handler() {
let mut handler = StaticKeyAuthMethodHandler::simple(HeapSecretKey::generate(32).unwrap());
handler
.on_error(Error {
kind: ErrorKind::Error,
text: "text".to_string(),
})
.await
.unwrap();
}
}

@ -0,0 +1,156 @@
use crate::common::HeapSecretKey;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Represents the result of a request to the database.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum KeychainResult<T> {
/// Id was not found in the database.
InvalidId,
/// Password match for an id failed.
InvalidPassword,
/// Successful match of id and password, removing from keychain and returning data `T`.
Ok(T),
}
impl<T> KeychainResult<T> {
pub fn is_invalid_id(&self) -> bool {
matches!(self, Self::InvalidId)
}
pub fn is_invalid_password(&self) -> bool {
matches!(self, Self::InvalidPassword)
}
pub fn is_invalid(&self) -> bool {
matches!(self, Self::InvalidId | Self::InvalidPassword)
}
pub fn is_ok(&self) -> bool {
matches!(self, Self::Ok(_))
}
pub fn into_ok(self) -> Option<T> {
match self {
Self::Ok(x) => Some(x),
_ => None,
}
}
}
impl<T> From<KeychainResult<T>> for Option<T> {
fn from(result: KeychainResult<T>) -> Self {
result.into_ok()
}
}
/// Manages keys with associated ids. Cloning will result in a copy pointing to the same underlying
/// storage, which enables support of managing the keys across multiple threads.
#[derive(Debug)]
pub struct Keychain<T = ()> {
map: Arc<RwLock<HashMap<String, (HeapSecretKey, T)>>>,
}
impl<T> Clone for Keychain<T> {
fn clone(&self) -> Self {
Self {
map: Arc::clone(&self.map),
}
}
}
impl<T> Keychain<T> {
/// Creates a new keychain without any keys.
pub fn new() -> Self {
Self {
map: Arc::new(RwLock::new(HashMap::new())),
}
}
/// Stores a new `key` and `data` by a given `id`, returning the old data associated with the
/// id if there was one already registered.
pub async fn insert(&self, id: impl Into<String>, key: HeapSecretKey, data: T) -> Option<T> {
self.map
.write()
.await
.insert(id.into(), (key, data))
.map(|(_, data)| data)
}
/// Checks if there is an `id` stored within the keychain.
pub async fn has_id(&self, id: impl AsRef<str>) -> bool {
self.map.read().await.contains_key(id.as_ref())
}
/// Checks if there is a key with the given `id` that matches the provided `key`.
pub async fn has_key(&self, id: impl AsRef<str>, key: impl PartialEq<HeapSecretKey>) -> bool {
self.map
.read()
.await
.get(id.as_ref())
.map(|(k, _)| key.eq(k))
.unwrap_or(false)
}
/// Removes a key and its data by a given `id`, returning the data if the `id` exists.
pub async fn remove(&self, id: impl AsRef<str>) -> Option<T> {
self.map
.write()
.await
.remove(id.as_ref())
.map(|(_, data)| data)
}
/// Checks if there is a key with the given `id` that matches the provided `key`, returning the
/// data if the `id` exists and the `key` matches.
pub async fn remove_if_has_key(
&self,
id: impl AsRef<str>,
key: impl PartialEq<HeapSecretKey>,
) -> KeychainResult<T> {
let id = id.as_ref();
let mut lock = self.map.write().await;
match lock.get(id) {
Some((k, _)) if key.eq(k) => KeychainResult::Ok(lock.remove(id).unwrap().1),
Some(_) => KeychainResult::InvalidPassword,
None => KeychainResult::InvalidId,
}
}
}
impl Keychain<()> {
/// Stores a new `key by a given `id`.
pub async fn put(&self, id: impl Into<String>, key: HeapSecretKey) {
self.insert(id, key, ()).await;
}
}
impl Default for Keychain {
fn default() -> Self {
Self::new()
}
}
impl<T> From<HashMap<String, (HeapSecretKey, T)>> for Keychain<T> {
/// Creates a new keychain populated with the provided `map`.
fn from(map: HashMap<String, (HeapSecretKey, T)>) -> Self {
Self {
map: Arc::new(RwLock::new(map)),
}
}
}
impl From<HashMap<String, HeapSecretKey>> for Keychain<()> {
/// Creates a new keychain populated with the provided `map`.
fn from(map: HashMap<String, HeapSecretKey>) -> Self {
Self::from(
map.into_iter()
.map(|(id, key)| (id, (key, ())))
.collect::<HashMap<String, (HeapSecretKey, ())>>(),
)
}
}

@ -0,0 +1,376 @@
use super::{super::HeapSecretKey, msg::*, Authenticator};
use async_trait::async_trait;
use log::*;
use std::collections::HashMap;
use std::io;
mod none;
mod static_key;
pub use none::*;
pub use static_key::*;
/// Supports authenticating using a variety of methods
pub struct Verifier {
methods: HashMap<&'static str, Box<dyn AuthenticationMethod>>,
}
impl Verifier {
pub fn new<I>(methods: I) -> Self
where
I: IntoIterator<Item = Box<dyn AuthenticationMethod>>,
{
let mut m = HashMap::new();
for method in methods {
m.insert(method.id(), method);
}
Self { methods: m }
}
/// Creates a verifier with no methods.
pub fn empty() -> Self {
Self {
methods: HashMap::new(),
}
}
/// Creates a verifier that uses the [`NoneAuthenticationMethod`] exclusively.
pub fn none() -> Self {
Self::new(vec![
Box::new(NoneAuthenticationMethod::new()) as Box<dyn AuthenticationMethod>
])
}
/// Creates a verifier that uses the [`StaticKeyAuthenticationMethod`] exclusively.
pub fn static_key(key: impl Into<HeapSecretKey>) -> Self {
Self::new(vec![
Box::new(StaticKeyAuthenticationMethod::new(key)) as Box<dyn AuthenticationMethod>
])
}
/// Returns an iterator over the ids of the methods supported by the verifier
pub fn methods(&self) -> impl Iterator<Item = &'static str> + '_ {
self.methods.keys().copied()
}
/// Attempts to verify by submitting challenges using the `authenticator` provided. Returns the
/// id of the authentication method that succeeded. Fails if no authentication method succeeds.
pub async fn verify(&self, authenticator: &mut dyn Authenticator) -> io::Result<&'static str> {
// Initiate the process to get methods to use
let response = authenticator
.initialize(Initialization {
methods: self.methods.keys().map(ToString::to_string).collect(),
})
.await?;
for method in response.methods {
match self.methods.get(method.as_str()) {
Some(method) => {
// Report the authentication method
authenticator
.start_method(StartMethod {
method: method.id().to_string(),
})
.await?;
// Perform the actual authentication
if method.authenticate(authenticator).await.is_ok() {
authenticator.finished().await?;
return Ok(method.id());
}
}
None => {
trace!("Skipping authentication {method} as it is not available or supported");
}
}
}
Err(io::Error::new(
io::ErrorKind::PermissionDenied,
"No authentication method succeeded",
))
}
}
impl From<Vec<Box<dyn AuthenticationMethod>>> for Verifier {
fn from(methods: Vec<Box<dyn AuthenticationMethod>>) -> Self {
Self::new(methods)
}
}
/// Represents an interface to authenticate using some method
#[async_trait]
pub trait AuthenticationMethod: Send + Sync {
/// Returns a unique id to distinguish the method from other methods
fn id(&self) -> &'static str;
/// Performs authentication using the `authenticator` to submit challenges and other
/// information based on the authentication method
async fn authenticate(&self, authenticator: &mut dyn Authenticator) -> io::Result<()>;
}
#[cfg(test)]
mod tests {
use super::*;
use crate::common::FramedTransport;
use test_log::test;
struct SuccessAuthenticationMethod;
#[async_trait]
impl AuthenticationMethod for SuccessAuthenticationMethod {
fn id(&self) -> &'static str {
"success"
}
async fn authenticate(&self, _: &mut dyn Authenticator) -> io::Result<()> {
Ok(())
}
}
struct FailAuthenticationMethod;
#[async_trait]
impl AuthenticationMethod for FailAuthenticationMethod {
fn id(&self) -> &'static str {
"fail"
}
async fn authenticate(&self, _: &mut dyn Authenticator) -> io::Result<()> {
Err(io::Error::from(io::ErrorKind::Other))
}
}
#[test(tokio::test)]
async fn verifier_should_fail_to_verify_if_initialization_fails() {
let (mut t1, mut t2) = FramedTransport::test_pair(100);
// Queue up a response to the initialization request
t2.write_frame(b"invalid initialization response")
.await
.unwrap();
let methods: Vec<Box<dyn AuthenticationMethod>> =
vec![Box::new(SuccessAuthenticationMethod)];
let verifier = Verifier::from(methods);
verifier.verify(&mut t1).await.unwrap_err();
}
#[test(tokio::test)]
async fn verifier_should_fail_to_verify_if_fails_to_send_finished_indicator_after_success() {
let (mut t1, mut t2) = FramedTransport::test_pair(100);
// Queue up a response to the initialization request
t2.write_frame_for(&AuthenticationResponse::Initialization(
InitializationResponse {
methods: vec![SuccessAuthenticationMethod.id().to_string()]
.into_iter()
.collect(),
},
))
.await
.unwrap();
// Then drop the transport so it cannot receive anything else
drop(t2);
let methods: Vec<Box<dyn AuthenticationMethod>> =
vec![Box::new(SuccessAuthenticationMethod)];
let verifier = Verifier::from(methods);
assert_eq!(
verifier.verify(&mut t1).await.unwrap_err().kind(),
io::ErrorKind::WriteZero
);
}
#[test(tokio::test)]
async fn verifier_should_fail_to_verify_if_has_no_authentication_methods() {
let (mut t1, mut t2) = FramedTransport::test_pair(100);
// Queue up a response to the initialization request
t2.write_frame_for(&AuthenticationResponse::Initialization(
InitializationResponse {
methods: vec![SuccessAuthenticationMethod.id().to_string()]
.into_iter()
.collect(),
},
))
.await
.unwrap();
let methods: Vec<Box<dyn AuthenticationMethod>> = vec![];
let verifier = Verifier::from(methods);
verifier.verify(&mut t1).await.unwrap_err();
}
#[test(tokio::test)]
async fn verifier_should_fail_to_verify_if_initialization_yields_no_valid_authentication_methods(
) {
let (mut t1, mut t2) = FramedTransport::test_pair(100);
// Queue up a response to the initialization request
t2.write_frame_for(&AuthenticationResponse::Initialization(
InitializationResponse {
methods: vec!["other".to_string()].into_iter().collect(),
},
))
.await
.unwrap();
let methods: Vec<Box<dyn AuthenticationMethod>> =
vec![Box::new(SuccessAuthenticationMethod)];
let verifier = Verifier::from(methods);
verifier.verify(&mut t1).await.unwrap_err();
}
#[test(tokio::test)]
async fn verifier_should_fail_to_verify_if_no_authentication_method_succeeds() {
let (mut t1, mut t2) = FramedTransport::test_pair(100);
// Queue up a response to the initialization request
t2.write_frame_for(&AuthenticationResponse::Initialization(
InitializationResponse {
methods: vec![FailAuthenticationMethod.id().to_string()]
.into_iter()
.collect(),
},
))
.await
.unwrap();
let methods: Vec<Box<dyn AuthenticationMethod>> = vec![Box::new(FailAuthenticationMethod)];
let verifier = Verifier::from(methods);
verifier.verify(&mut t1).await.unwrap_err();
}
#[test(tokio::test)]
async fn verifier_should_return_id_of_authentication_method_upon_success() {
let (mut t1, mut t2) = FramedTransport::test_pair(100);
// Queue up a response to the initialization request
t2.write_frame_for(&AuthenticationResponse::Initialization(
InitializationResponse {
methods: vec![SuccessAuthenticationMethod.id().to_string()]
.into_iter()
.collect(),
},
))
.await
.unwrap();
let methods: Vec<Box<dyn AuthenticationMethod>> =
vec![Box::new(SuccessAuthenticationMethod)];
let verifier = Verifier::from(methods);
assert_eq!(
verifier.verify(&mut t1).await.unwrap(),
SuccessAuthenticationMethod.id()
);
}
#[test(tokio::test)]
async fn verifier_should_try_authentication_methods_in_order_until_one_succeeds() {
let (mut t1, mut t2) = FramedTransport::test_pair(100);
// Queue up a response to the initialization request
t2.write_frame_for(&AuthenticationResponse::Initialization(
InitializationResponse {
methods: vec![
FailAuthenticationMethod.id().to_string(),
SuccessAuthenticationMethod.id().to_string(),
]
.into_iter()
.collect(),
},
))
.await
.unwrap();
let methods: Vec<Box<dyn AuthenticationMethod>> = vec![
Box::new(FailAuthenticationMethod),
Box::new(SuccessAuthenticationMethod),
];
let verifier = Verifier::from(methods);
assert_eq!(
verifier.verify(&mut t1).await.unwrap(),
SuccessAuthenticationMethod.id()
);
}
#[test(tokio::test)]
async fn verifier_should_send_start_method_before_attempting_each_method() {
let (mut t1, mut t2) = FramedTransport::test_pair(100);
// Queue up a response to the initialization request
t2.write_frame_for(&AuthenticationResponse::Initialization(
InitializationResponse {
methods: vec![
FailAuthenticationMethod.id().to_string(),
SuccessAuthenticationMethod.id().to_string(),
]
.into_iter()
.collect(),
},
))
.await
.unwrap();
let methods: Vec<Box<dyn AuthenticationMethod>> = vec![
Box::new(FailAuthenticationMethod),
Box::new(SuccessAuthenticationMethod),
];
Verifier::from(methods).verify(&mut t1).await.unwrap();
// Check that we get a start method for each of the attempted methods
match t2.read_frame_as::<Authentication>().await.unwrap().unwrap() {
Authentication::Initialization(_) => (),
x => panic!("Unexpected response: {x:?}"),
}
match t2.read_frame_as::<Authentication>().await.unwrap().unwrap() {
Authentication::StartMethod(x) => assert_eq!(x.method, FailAuthenticationMethod.id()),
x => panic!("Unexpected response: {x:?}"),
}
match t2.read_frame_as::<Authentication>().await.unwrap().unwrap() {
Authentication::StartMethod(x) => {
assert_eq!(x.method, SuccessAuthenticationMethod.id())
}
x => panic!("Unexpected response: {x:?}"),
}
}
#[test(tokio::test)]
async fn verifier_should_send_finished_when_a_method_succeeds() {
let (mut t1, mut t2) = FramedTransport::test_pair(100);
// Queue up a response to the initialization request
t2.write_frame_for(&AuthenticationResponse::Initialization(
InitializationResponse {
methods: vec![
FailAuthenticationMethod.id().to_string(),
SuccessAuthenticationMethod.id().to_string(),
]
.into_iter()
.collect(),
},
))
.await
.unwrap();
let methods: Vec<Box<dyn AuthenticationMethod>> = vec![
Box::new(FailAuthenticationMethod),
Box::new(SuccessAuthenticationMethod),
];
Verifier::from(methods).verify(&mut t1).await.unwrap();
// Clear out the initialization and start methods
t2.read_frame_as::<Authentication>().await.unwrap().unwrap();
t2.read_frame_as::<Authentication>().await.unwrap().unwrap();
t2.read_frame_as::<Authentication>().await.unwrap().unwrap();
match t2.read_frame_as::<Authentication>().await.unwrap().unwrap() {
Authentication::Finished => (),
x => panic!("Unexpected response: {x:?}"),
}
}
}

@ -0,0 +1,32 @@
use super::{AuthenticationMethod, Authenticator};
use async_trait::async_trait;
use std::io;
/// Authenticaton method for a static secret key
#[derive(Clone, Debug)]
pub struct NoneAuthenticationMethod;
impl NoneAuthenticationMethod {
#[inline]
pub fn new() -> Self {
Self
}
}
impl Default for NoneAuthenticationMethod {
#[inline]
fn default() -> Self {
Self
}
}
#[async_trait]
impl AuthenticationMethod for NoneAuthenticationMethod {
fn id(&self) -> &'static str {
"none"
}
async fn authenticate(&self, _: &mut dyn Authenticator) -> io::Result<()> {
Ok(())
}
}

@ -0,0 +1,129 @@
use super::{AuthenticationMethod, Authenticator, Challenge, Error, Question};
use crate::common::HeapSecretKey;
use async_trait::async_trait;
use std::io;
/// Authenticaton method for a static secret key
#[derive(Clone, Debug)]
pub struct StaticKeyAuthenticationMethod {
key: HeapSecretKey,
}
impl StaticKeyAuthenticationMethod {
#[inline]
pub fn new(key: impl Into<HeapSecretKey>) -> Self {
Self { key: key.into() }
}
}
#[async_trait]
impl AuthenticationMethod for StaticKeyAuthenticationMethod {
fn id(&self) -> &'static str {
"static_key"
}
async fn authenticate(&self, authenticator: &mut dyn Authenticator) -> io::Result<()> {
let response = authenticator
.challenge(Challenge {
questions: vec![Question {
label: "key".to_string(),
text: "Provide a key: ".to_string(),
options: Default::default(),
}],
options: Default::default(),
})
.await?;
if response.answers.is_empty() {
return Err(Error::non_fatal("missing answer").into_io_permission_denied());
}
match response
.answers
.into_iter()
.next()
.unwrap()
.parse::<HeapSecretKey>()
{
Ok(key) if key == self.key => Ok(()),
_ => Err(Error::non_fatal("answer does not match key").into_io_permission_denied()),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::common::{
authentication::msg::{AuthenticationResponse, ChallengeResponse},
FramedTransport,
};
use test_log::test;
#[test(tokio::test)]
async fn authenticate_should_fail_if_key_challenge_fails() {
let method = StaticKeyAuthenticationMethod::new(b"".to_vec());
let (mut t1, mut t2) = FramedTransport::test_pair(100);
// Queue up an invalid frame for our challenge to ensure it fails
t2.write_frame(b"invalid initialization response")
.await
.unwrap();
assert_eq!(
method.authenticate(&mut t1).await.unwrap_err().kind(),
io::ErrorKind::InvalidData
);
}
#[test(tokio::test)]
async fn authenticate_should_fail_if_no_answer_included_in_challenge_response() {
let method = StaticKeyAuthenticationMethod::new(b"".to_vec());
let (mut t1, mut t2) = FramedTransport::test_pair(100);
// Queue up a response to the initialization request
t2.write_frame_for(&AuthenticationResponse::Challenge(ChallengeResponse {
answers: Vec::new(),
}))
.await
.unwrap();
assert_eq!(
method.authenticate(&mut t1).await.unwrap_err().kind(),
io::ErrorKind::PermissionDenied
);
}
#[test(tokio::test)]
async fn authenticate_should_fail_if_answer_does_not_match_key() {
let method = StaticKeyAuthenticationMethod::new(b"answer".to_vec());
let (mut t1, mut t2) = FramedTransport::test_pair(100);
// Queue up a response to the initialization request
t2.write_frame_for(&AuthenticationResponse::Challenge(ChallengeResponse {
answers: vec![HeapSecretKey::from(b"some key".to_vec()).to_string()],
}))
.await
.unwrap();
assert_eq!(
method.authenticate(&mut t1).await.unwrap_err().kind(),
io::ErrorKind::PermissionDenied
);
}
#[test(tokio::test)]
async fn authenticate_should_succeed_if_answer_matches_key() {
let method = StaticKeyAuthenticationMethod::new(b"answer".to_vec());
let (mut t1, mut t2) = FramedTransport::test_pair(100);
// Queue up a response to the initialization request
t2.write_frame_for(&AuthenticationResponse::Challenge(ChallengeResponse {
answers: vec![HeapSecretKey::from(b"answer".to_vec()).to_string()],
}))
.await
.unwrap();
method.authenticate(&mut t1).await.unwrap();
}
}

@ -0,0 +1,216 @@
use derive_more::{Display, Error, From};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Represents messages from an authenticator that act as initiators such as providing
/// a challenge, verifying information, presenting information, or highlighting an error
#[derive(Clone, Debug, From, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum Authentication {
/// Indicates the beginning of authentication, providing available methods
#[serde(rename = "auth_initialization")]
Initialization(Initialization),
/// Indicates that authentication is starting for the specific `method`
#[serde(rename = "auth_start_method")]
StartMethod(StartMethod),
/// Issues a challenge to be answered
#[serde(rename = "auth_challenge")]
Challenge(Challenge),
/// Requests verification of some text
#[serde(rename = "auth_verification")]
Verification(Verification),
/// Reports some information associated with authentication
#[serde(rename = "auth_info")]
Info(Info),
/// Reports an error occurrred during authentication
#[serde(rename = "auth_error")]
Error(Error),
/// Indicates that the authentication of all methods is finished
#[serde(rename = "auth_finished")]
Finished,
}
/// Represents the beginning of the authentication procedure
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Initialization {
/// Available methods to use for authentication
pub methods: Vec<String>,
}
/// Represents the start of authentication for some method
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct StartMethod {
pub method: String,
}
/// Represents a challenge comprising a series of questions to be presented
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Challenge {
pub questions: Vec<Question>,
pub options: HashMap<String, String>,
}
/// Represents an ask to verify some information
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Verification {
pub kind: VerificationKind,
pub text: String,
}
/// Represents some information to be presented related to authentication
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Info {
pub text: String,
}
/// Represents authentication messages that are responses to authenticator requests such
/// as answers to challenges or verifying information
#[derive(Clone, Debug, From, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum AuthenticationResponse {
/// Contains response to initialization, providing details about which methods to use
#[serde(rename = "auth_initialization_response")]
Initialization(InitializationResponse),
/// Contains answers to challenge request
#[serde(rename = "auth_challenge_response")]
Challenge(ChallengeResponse),
/// Contains response to a verification request
#[serde(rename = "auth_verification_response")]
Verification(VerificationResponse),
}
/// Represents a response to initialization to specify which authentication methods to pursue
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct InitializationResponse {
/// Methods to use (in order as provided)
pub methods: Vec<String>,
}
/// Represents the answers to a previously-asked challenge associated with authentication
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ChallengeResponse {
/// Answers to challenge questions (in order relative to questions)
pub answers: Vec<String>,
}
/// Represents the answer to a previously-asked verification associated with authentication
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct VerificationResponse {
/// Whether or not the verification was deemed valid
pub valid: bool,
}
/// Represents the type of verification being requested
#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum VerificationKind {
/// An ask to verify the host such as with SSH
#[display(fmt = "host")]
Host,
/// When the verification is unknown (happens when other side is unaware of the kind)
#[display(fmt = "unknown")]
#[serde(other)]
Unknown,
}
impl VerificationKind {
/// Returns all variants except "unknown"
pub const fn known_variants() -> &'static [Self] {
&[Self::Host]
}
}
/// Represents a single question in a challenge associated with authentication
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Question {
/// Label associated with the question for more programmatic usage
pub label: String,
/// The text of the question (used for display purposes)
pub text: String,
/// Any options information specific to a particular auth domain
/// such as including a username and instructions for SSH authentication
pub options: HashMap<String, String>,
}
impl Question {
/// Creates a new question without any options data using `text` for both label and text
pub fn new(text: impl Into<String>) -> Self {
let text = text.into();
Self {
label: text.clone(),
text,
options: HashMap::new(),
}
}
}
/// Represents some error that occurred during authentication
#[derive(Clone, Debug, Display, Error, PartialEq, Eq, Serialize, Deserialize)]
#[display(fmt = "{}: {}", kind, text)]
pub struct Error {
/// Represents the kind of error
pub kind: ErrorKind,
/// Description of the error
pub text: String,
}
impl Error {
/// Creates a fatal error
pub fn fatal(text: impl Into<String>) -> Self {
Self {
kind: ErrorKind::Fatal,
text: text.into(),
}
}
/// Creates a non-fatal error
pub fn non_fatal(text: impl Into<String>) -> Self {
Self {
kind: ErrorKind::Error,
text: text.into(),
}
}
/// Returns true if error represents a fatal error, meaning that there is no recovery possible
/// from this error
pub fn is_fatal(&self) -> bool {
self.kind.is_fatal()
}
/// Converts the error into a [`std::io::Error`] representing permission denied
pub fn into_io_permission_denied(self) -> std::io::Error {
std::io::Error::new(std::io::ErrorKind::PermissionDenied, self)
}
}
/// Represents the type of error encountered during authentication
#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ErrorKind {
/// Error is unrecoverable
Fatal,
/// Error is recoverable
Error,
}
impl ErrorKind {
/// Returns true if error kind represents a fatal error, meaning that there is no recovery
/// possible from this error
pub fn is_fatal(self) -> bool {
matches!(self, Self::Fatal)
}
}

File diff suppressed because it is too large Load Diff

@ -1,4 +1,4 @@
use crate::serde_str::{deserialize_from_str, serialize_to_str};
use super::utils::{deserialize_from_str, serialize_to_str};
use serde::{de::Deserializer, ser::Serializer, Deserialize, Serialize};
use std::{fmt, hash::Hash, str::FromStr};
@ -38,17 +38,8 @@ pub struct Destination {
}
impl Destination {
/// Returns true if destination represents a distant server
pub fn is_distant(&self) -> bool {
self.scheme_eq("distant")
}
/// Returns true if destination represents an ssh server
pub fn is_ssh(&self) -> bool {
self.scheme_eq("ssh")
}
fn scheme_eq(&self, s: &str) -> bool {
/// Returns true if the destination's scheme represents the specified (case-insensitive).
pub fn scheme_eq(&self, s: &str) -> bool {
match self.scheme.as_ref() {
Some(scheme) => scheme.eq_ignore_ascii_case(s),
None => false,
@ -58,13 +49,13 @@ impl Destination {
impl AsRef<Destination> for &Destination {
fn as_ref(&self) -> &Destination {
*self
self
}
}
impl AsMut<Destination> for &mut Destination {
fn as_mut(&mut self) -> &mut Destination {
*self
self
}
}

@ -1,4 +1,4 @@
use crate::serde_str::{deserialize_from_str, serialize_to_str};
use super::{deserialize_from_str, serialize_to_str};
use derive_more::{Display, Error, From};
use serde::{de::Deserializer, ser::Serializer, Deserialize, Serialize};
use std::{
@ -109,7 +109,7 @@ impl FromStr for Host {
/// ### Examples
///
/// ```
/// # use distant_core::Host;
/// # use distant_net::common::Host;
/// # use std::net::{Ipv4Addr, Ipv6Addr};
/// // IPv4 address
/// assert_eq!("127.0.0.1".parse(), Ok(Host::Ipv4(Ipv4Addr::new(127, 0, 0, 1))));

@ -1,4 +1,4 @@
use crate::Listener;
use super::Listener;
use async_trait::async_trait;
use std::io;

@ -1,4 +1,4 @@
use crate::Listener;
use super::Listener;
use async_trait::async_trait;
use derive_more::From;
use std::io;

@ -1,4 +1,4 @@
use crate::Listener;
use super::Listener;
use async_trait::async_trait;
use derive_more::From;
use std::io;
@ -48,9 +48,10 @@ impl<T: Send> Listener for OneshotListener<T> {
#[cfg(test)]
mod tests {
use super::*;
use test_log::test;
use tokio::task::JoinHandle;
#[tokio::test]
#[test(tokio::test)]
async fn from_value_should_return_value_on_first_call_to_accept() {
let mut listener = OneshotListener::from_value("hello world");
assert_eq!(listener.accept().await.unwrap(), "hello world");
@ -60,7 +61,7 @@ mod tests {
);
}
#[tokio::test]
#[test(tokio::test)]
async fn channel_should_return_a_oneshot_sender_to_feed_first_call_to_accept() {
let (tx, mut listener) = OneshotListener::channel();
let accept_task: JoinHandle<(io::Result<&str>, io::Result<&str>)> =

@ -1,4 +1,5 @@
use crate::{Listener, PortRange, TcpTransport};
use super::Listener;
use crate::common::{PortRange, TcpTransport};
use async_trait::async_trait;
use std::{fmt, io, net::IpAddr};
use tokio::net::TcpListener as TokioTcpListener;
@ -64,14 +65,12 @@ impl Listener for TcpListener {
#[cfg(test)]
mod tests {
use super::*;
use crate::common::TransportExt;
use std::net::{Ipv6Addr, SocketAddr};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
sync::oneshot,
task::JoinHandle,
};
use test_log::test;
use tokio::{sync::oneshot, task::JoinHandle};
#[tokio::test]
#[test(tokio::test)]
async fn should_fail_to_bind_if_port_already_bound() {
let addr = IpAddr::V6(Ipv6Addr::LOCALHOST);
let port = 0; // Ephemeral port
@ -91,8 +90,8 @@ mod tests {
));
}
#[tokio::test]
async fn should_be_able_to_receive_connections_and_send_and_receive_data_with_them() {
#[test(tokio::test)]
async fn should_be_able_to_receive_connections_and_read_and_write_data_with_them() {
let (tx, rx) = oneshot::channel();
// Spawn a task that will wait for two connections and then
@ -109,7 +108,7 @@ mod tests {
.map_err(|x| io::Error::new(io::ErrorKind::Other, x.to_string()))?;
// Get first connection
let mut conn_1 = listener.accept().await?;
let conn_1 = listener.accept().await?;
// Send some data to the first connection (12 bytes)
conn_1.write_all(b"hello conn 1").await?;
@ -120,7 +119,7 @@ mod tests {
assert_eq!(&buf, b"hello server 1");
// Get second connection
let mut conn_2 = listener.accept().await?;
let conn_2 = listener.accept().await?;
// Send some data on to second connection (12 bytes)
conn_2.write_all(b"hello conn 2").await?;
@ -139,7 +138,7 @@ mod tests {
// Connect to the listener twice, sending some bytes and receiving some bytes from each
let mut buf: [u8; 12] = [0; 12];
let mut conn = TcpTransport::connect(&address)
let conn = TcpTransport::connect(&address)
.await
.expect("Conn 1 failed to connect");
conn.write_all(b"hello server 1")
@ -150,7 +149,7 @@ mod tests {
.expect("Conn 1 failed to read");
assert_eq!(&buf, b"hello conn 1");
let mut conn = TcpTransport::connect(&address)
let conn = TcpTransport::connect(&address)
.await
.expect("Conn 2 failed to connect");
conn.write_all(b"hello server 2")

@ -1,4 +1,5 @@
use crate::{Listener, UnixSocketTransport};
use super::Listener;
use crate::common::UnixSocketTransport;
use async_trait::async_trait;
use std::{
fmt, io,
@ -94,14 +95,12 @@ impl Listener for UnixSocketListener {
#[cfg(test)]
mod tests {
use super::*;
use crate::common::TransportExt;
use tempfile::NamedTempFile;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
sync::oneshot,
task::JoinHandle,
};
use test_log::test;
use tokio::{sync::oneshot, task::JoinHandle};
#[tokio::test]
#[test(tokio::test)]
async fn should_succeed_to_bind_if_file_exists_at_path_but_nothing_listening() {
// Generate a socket path
let path = NamedTempFile::new()
@ -114,7 +113,7 @@ mod tests {
.expect("Unexpectedly failed to bind to existing file");
}
#[tokio::test]
#[test(tokio::test)]
async fn should_fail_to_bind_if_socket_already_bound() {
// Generate a socket path and delete the file after
let path = NamedTempFile::new()
@ -133,8 +132,8 @@ mod tests {
.expect_err("Unexpectedly succeeded in binding to same socket");
}
#[tokio::test]
async fn should_be_able_to_receive_connections_and_send_and_receive_data_with_them() {
#[test(tokio::test)]
async fn should_be_able_to_receive_connections_and_read_and_write_data_with_them() {
let (tx, rx) = oneshot::channel();
// Spawn a task that will wait for two connections and then
@ -154,7 +153,7 @@ mod tests {
.map_err(|x| io::Error::new(io::ErrorKind::Other, x.display().to_string()))?;
// Get first connection
let mut conn_1 = listener.accept().await?;
let conn_1 = listener.accept().await?;
// Send some data to the first connection (12 bytes)
conn_1.write_all(b"hello conn 1").await?;
@ -165,7 +164,7 @@ mod tests {
assert_eq!(&buf, b"hello server 1");
// Get second connection
let mut conn_2 = listener.accept().await?;
let conn_2 = listener.accept().await?;
// Send some data on to second connection (12 bytes)
conn_2.write_all(b"hello conn 2").await?;
@ -184,7 +183,7 @@ mod tests {
// Connect to the listener twice, sending some bytes and receiving some bytes from each
let mut buf: [u8; 12] = [0; 12];
let mut conn = UnixSocketTransport::connect(&path)
let conn = UnixSocketTransport::connect(&path)
.await
.expect("Conn 1 failed to connect");
conn.write_all(b"hello server 1")
@ -195,7 +194,7 @@ mod tests {
.expect("Conn 1 failed to read");
assert_eq!(&buf, b"hello conn 1");
let mut conn = UnixSocketTransport::connect(&path)
let conn = UnixSocketTransport::connect(&path)
.await
.expect("Conn 2 failed to connect");
conn.write_all(b"hello server 2")

@ -1,4 +1,5 @@
use crate::{Listener, NamedPipe, WindowsPipeTransport};
use super::Listener;
use crate::common::{NamedPipe, WindowsPipeTransport};
use async_trait::async_trait;
use std::{
ffi::{OsStr, OsString},
@ -66,13 +67,11 @@ impl Listener for WindowsPipeListener {
#[cfg(test)]
mod tests {
use super::*;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
sync::oneshot,
task::JoinHandle,
};
use crate::common::TransportExt;
use test_log::test;
use tokio::{sync::oneshot, task::JoinHandle};
#[tokio::test]
#[test(tokio::test)]
async fn should_fail_to_bind_if_pipe_already_bound() {
// Generate a pipe name
let name = format!("test_pipe_{}", rand::random::<usize>());
@ -86,8 +85,8 @@ mod tests {
.expect_err("Unexpectedly succeeded in binding to same pipe");
}
#[tokio::test]
async fn should_be_able_to_receive_connections_and_send_and_receive_data_with_them() {
#[test(tokio::test)]
async fn should_be_able_to_receive_connections_and_read_and_write_data_with_them() {
let (tx, rx) = oneshot::channel();
// Spawn a task that will wait for two connections and then
@ -104,7 +103,7 @@ mod tests {
.map_err(|x| io::Error::new(io::ErrorKind::Other, x))?;
// Get first connection
let mut conn_1 = listener.accept().await?;
let conn_1 = listener.accept().await?;
// Send some data to the first connection (12 bytes)
conn_1.write_all(b"hello conn 1").await?;
@ -115,7 +114,7 @@ mod tests {
assert_eq!(&buf, b"hello server 1");
// Get second connection
let mut conn_2 = listener.accept().await?;
let conn_2 = listener.accept().await?;
// Send some data on to second connection (12 bytes)
conn_2.write_all(b"hello conn 2").await?;
@ -134,7 +133,7 @@ mod tests {
// Connect to the listener twice, sending some bytes and receiving some bytes from each
let mut buf: [u8; 12] = [0; 12];
let mut conn = WindowsPipeTransport::connect_local(&name)
let conn = WindowsPipeTransport::connect_local(&name)
.await
.expect("Conn 1 failed to connect");
conn.write_all(b"hello server 1")
@ -145,7 +144,7 @@ mod tests {
.expect("Conn 1 failed to read");
assert_eq!(&buf, b"hello conn 1");
let mut conn = WindowsPipeTransport::connect_local(&name)
let conn = WindowsPipeTransport::connect_local(&name)
.await
.expect("Conn 2 failed to connect");
conn.write_all(b"hello server 2")

@ -1,4 +1,4 @@
use crate::serde_str::{deserialize_from_str, serialize_to_str};
use crate::common::utils::{deserialize_from_str, serialize_to_str};
use derive_more::{Display, Error, From, IntoIterator};
use serde::{de::Deserializer, ser::Serializer, Deserialize, Serialize};
use std::{
@ -198,6 +198,13 @@ impl<'de> Deserialize<'de> for Map {
}
}
/// Generates a new [`Map`] of key/value pairs based on literals.
///
/// ```
/// use distant_net::map;
///
/// let _map = map!("key" -> "value", "key2" -> "value2");
/// ```
#[macro_export]
macro_rules! map {
($($key:literal -> $value:literal),* $(,)?) => {{
@ -207,7 +214,7 @@ macro_rules! map {
_map.insert($key.to_string(), $value.to_string());
)*
$crate::Map::from(_map)
$crate::common::Map::from(_map)
}};
}

@ -0,0 +1,628 @@
/// Represents a generic id type
pub type Id = String;
mod request;
mod response;
pub use request::*;
pub use response::*;
#[derive(Clone, Debug, PartialEq, Eq)]
enum MsgPackStrParseError {
InvalidFormat,
Utf8Error(std::str::Utf8Error),
}
/// Writes the given str to the end of `buf` as the str's msgpack representation.
///
/// # Panics
///
/// Panics if `s.len() >= 2 ^ 32` as the maximum str length for a msgpack str is `(2 ^ 32) - 1`.
fn write_str_msg_pack(s: &str, buf: &mut Vec<u8>) {
assert!(
s.len() < 2usize.pow(32),
"str cannot be longer than (2^32)-1 bytes"
);
if s.len() < 32 {
buf.push(s.len() as u8 | 0b10100000);
} else if s.len() < 2usize.pow(8) {
buf.push(0xd9);
buf.push(s.len() as u8);
} else if s.len() < 2usize.pow(16) {
buf.push(0xda);
for b in (s.len() as u16).to_be_bytes() {
buf.push(b);
}
} else {
buf.push(0xdb);
for b in (s.len() as u32).to_be_bytes() {
buf.push(b);
}
}
buf.extend_from_slice(s.as_bytes());
}
/// Parse msgpack str, returning remaining bytes and str on success, or error on failure.
fn parse_msg_pack_str(input: &[u8]) -> Result<(&[u8], &str), MsgPackStrParseError> {
let ilen = input.len();
if ilen == 0 {
return Err(MsgPackStrParseError::InvalidFormat);
}
// * fixstr using 0xa0 - 0xbf to mark the start of the str where < 32 bytes
// * str 8 (0xd9) if up to (2^8)-1 bytes, using next byte for len
// * str 16 (0xda) if up to (2^16)-1 bytes, using next two bytes for len
// * str 32 (0xdb) if up to (2^32)-1 bytes, using next four bytes for len
let (input, len): (&[u8], usize) = if input[0] >= 0xa0 && input[0] <= 0xbf {
(&input[1..], (input[0] & 0b00011111).into())
} else if input[0] == 0xd9 && ilen > 2 {
(&input[2..], input[1].into())
} else if input[0] == 0xda && ilen > 3 {
(&input[3..], u16::from_be_bytes([input[1], input[2]]).into())
} else if input[0] == 0xdb && ilen > 5 {
(
&input[5..],
u32::from_be_bytes([input[1], input[2], input[3], input[4]])
.try_into()
.unwrap(),
)
} else {
return Err(MsgPackStrParseError::InvalidFormat);
};
let s = match std::str::from_utf8(&input[..len]) {
Ok(s) => s,
Err(x) => return Err(MsgPackStrParseError::Utf8Error(x)),
};
Ok((&input[len..], s))
}
#[cfg(test)]
mod tests {
use super::*;
mod write_str_msg_pack {
use super::*;
#[test]
fn should_support_fixstr() {
// 0-byte str
let mut buf = Vec::new();
write_str_msg_pack("", &mut buf);
assert_eq!(buf, &[0xa0]);
// 1-byte str
let mut buf = Vec::new();
write_str_msg_pack("a", &mut buf);
assert_eq!(buf, &[0xa1, b'a']);
// 2-byte str
let mut buf = Vec::new();
write_str_msg_pack("ab", &mut buf);
assert_eq!(buf, &[0xa2, b'a', b'b']);
// 3-byte str
let mut buf = Vec::new();
write_str_msg_pack("abc", &mut buf);
assert_eq!(buf, &[0xa3, b'a', b'b', b'c']);
// 4-byte str
let mut buf = Vec::new();
write_str_msg_pack("abcd", &mut buf);
assert_eq!(buf, &[0xa4, b'a', b'b', b'c', b'd']);
// 5-byte str
let mut buf = Vec::new();
write_str_msg_pack("abcde", &mut buf);
assert_eq!(buf, &[0xa5, b'a', b'b', b'c', b'd', b'e']);
// 6-byte str
let mut buf = Vec::new();
write_str_msg_pack("abcdef", &mut buf);
assert_eq!(buf, &[0xa6, b'a', b'b', b'c', b'd', b'e', b'f']);
// 7-byte str
let mut buf = Vec::new();
write_str_msg_pack("abcdefg", &mut buf);
assert_eq!(buf, &[0xa7, b'a', b'b', b'c', b'd', b'e', b'f', b'g']);
// 8-byte str
let mut buf = Vec::new();
write_str_msg_pack("abcdefgh", &mut buf);
assert_eq!(buf, &[0xa8, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h']);
// 9-byte str
let mut buf = Vec::new();
write_str_msg_pack("abcdefghi", &mut buf);
assert_eq!(
buf,
&[0xa9, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i']
);
// 10-byte str
let mut buf = Vec::new();
write_str_msg_pack("abcdefghij", &mut buf);
assert_eq!(
buf,
&[0xaa, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j']
);
// 11-byte str
let mut buf = Vec::new();
write_str_msg_pack("abcdefghijk", &mut buf);
assert_eq!(
buf,
&[0xab, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k']
);
// 12-byte str
let mut buf = Vec::new();
write_str_msg_pack("abcdefghijkl", &mut buf);
assert_eq!(
buf,
&[0xac, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l']
);
// 13-byte str
let mut buf = Vec::new();
write_str_msg_pack("abcdefghijklm", &mut buf);
assert_eq!(
buf,
&[
0xad, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
b'm'
]
);
// 14-byte str
let mut buf = Vec::new();
write_str_msg_pack("abcdefghijklmn", &mut buf);
assert_eq!(
buf,
&[
0xae, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
b'm', b'n'
]
);
// 15-byte str
let mut buf = Vec::new();
write_str_msg_pack("abcdefghijklmno", &mut buf);
assert_eq!(
buf,
&[
0xaf, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
b'm', b'n', b'o'
]
);
// 16-byte str
let mut buf = Vec::new();
write_str_msg_pack("abcdefghijklmnop", &mut buf);
assert_eq!(
buf,
&[
0xb0, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
b'm', b'n', b'o', b'p'
]
);
// 17-byte str
let mut buf = Vec::new();
write_str_msg_pack("abcdefghijklmnopq", &mut buf);
assert_eq!(
buf,
&[
0xb1, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
b'm', b'n', b'o', b'p', b'q'
]
);
// 18-byte str
let mut buf = Vec::new();
write_str_msg_pack("abcdefghijklmnopqr", &mut buf);
assert_eq!(
buf,
&[
0xb2, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
b'm', b'n', b'o', b'p', b'q', b'r'
]
);
// 19-byte str
let mut buf = Vec::new();
write_str_msg_pack("abcdefghijklmnopqrs", &mut buf);
assert_eq!(
buf,
&[
0xb3, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
b'm', b'n', b'o', b'p', b'q', b'r', b's'
]
);
// 20-byte str
let mut buf = Vec::new();
write_str_msg_pack("abcdefghijklmnopqrst", &mut buf);
assert_eq!(
buf,
&[
0xb4, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
b'm', b'n', b'o', b'p', b'q', b'r', b's', b't'
]
);
// 21-byte str
let mut buf = Vec::new();
write_str_msg_pack("abcdefghijklmnopqrstu", &mut buf);
assert_eq!(
buf,
&[
0xb5, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u'
]
);
// 22-byte str
let mut buf = Vec::new();
write_str_msg_pack("abcdefghijklmnopqrstuv", &mut buf);
assert_eq!(
buf,
&[
0xb6, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v'
]
);
// 23-byte str
let mut buf = Vec::new();
write_str_msg_pack("abcdefghijklmnopqrstuvw", &mut buf);
assert_eq!(
buf,
&[
0xb7, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w'
]
);
// 24-byte str
let mut buf = Vec::new();
write_str_msg_pack("abcdefghijklmnopqrstuvwx", &mut buf);
assert_eq!(
buf,
&[
0xb8, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x'
]
);
// 25-byte str
let mut buf = Vec::new();
write_str_msg_pack("abcdefghijklmnopqrstuvwxy", &mut buf);
assert_eq!(
buf,
&[
0xb9, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y'
]
);
// 26-byte str
let mut buf = Vec::new();
write_str_msg_pack("abcdefghijklmnopqrstuvwxyz", &mut buf);
assert_eq!(
buf,
&[
0xba, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y',
b'z'
]
);
// 27-byte str
let mut buf = Vec::new();
write_str_msg_pack("abcdefghijklmnopqrstuvwxyz0", &mut buf);
assert_eq!(
buf,
&[
0xbb, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y',
b'z', b'0'
]
);
// 28-byte str
let mut buf = Vec::new();
write_str_msg_pack("abcdefghijklmnopqrstuvwxyz01", &mut buf);
assert_eq!(
buf,
&[
0xbc, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y',
b'z', b'0', b'1'
]
);
// 29-byte str
let mut buf = Vec::new();
write_str_msg_pack("abcdefghijklmnopqrstuvwxyz012", &mut buf);
assert_eq!(
buf,
&[
0xbd, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y',
b'z', b'0', b'1', b'2'
]
);
// 30-byte str
let mut buf = Vec::new();
write_str_msg_pack("abcdefghijklmnopqrstuvwxyz0123", &mut buf);
assert_eq!(
buf,
&[
0xbe, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y',
b'z', b'0', b'1', b'2', b'3'
]
);
// 31-byte str is maximum len of fixstr
let mut buf = Vec::new();
write_str_msg_pack("abcdefghijklmnopqrstuvwxyz01234", &mut buf);
assert_eq!(
buf,
&[
0xbf, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y',
b'z', b'0', b'1', b'2', b'3', b'4'
]
);
}
#[test]
fn should_support_str_8() {
let input = "a".repeat(32);
let mut buf = Vec::new();
write_str_msg_pack(&input, &mut buf);
assert_eq!(buf[0], 0xd9);
assert_eq!(buf[1], input.len() as u8);
assert_eq!(&buf[2..], input.as_bytes());
let input = "a".repeat(2usize.pow(8) - 1);
let mut buf = Vec::new();
write_str_msg_pack(&input, &mut buf);
assert_eq!(buf[0], 0xd9);
assert_eq!(buf[1], input.len() as u8);
assert_eq!(&buf[2..], input.as_bytes());
}
#[test]
fn should_support_str_16() {
let input = "a".repeat(2usize.pow(8));
let mut buf = Vec::new();
write_str_msg_pack(&input, &mut buf);
assert_eq!(buf[0], 0xda);
assert_eq!(&buf[1..3], &(input.len() as u16).to_be_bytes());
assert_eq!(&buf[3..], input.as_bytes());
let input = "a".repeat(2usize.pow(16) - 1);
let mut buf = Vec::new();
write_str_msg_pack(&input, &mut buf);
assert_eq!(buf[0], 0xda);
assert_eq!(&buf[1..3], &(input.len() as u16).to_be_bytes());
assert_eq!(&buf[3..], input.as_bytes());
}
#[test]
fn should_support_str_32() {
let input = "a".repeat(2usize.pow(16));
let mut buf = Vec::new();
write_str_msg_pack(&input, &mut buf);
assert_eq!(buf[0], 0xdb);
assert_eq!(&buf[1..5], &(input.len() as u32).to_be_bytes());
assert_eq!(&buf[5..], input.as_bytes());
}
}
mod parse_msg_pack_str {
use super::*;
#[test]
fn should_be_able_to_parse_fixstr() {
// Empty str
let (input, s) = parse_msg_pack_str(&[0xa0]).unwrap();
assert!(input.is_empty());
assert_eq!(s, "");
// Single character
let (input, s) = parse_msg_pack_str(&[0xa1, b'a']).unwrap();
assert!(input.is_empty());
assert_eq!(s, "a");
// 31 byte str
let (input, s) = parse_msg_pack_str(&[
0xbf, b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a',
b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a',
b'a', b'a', b'a', b'a',
])
.unwrap();
assert!(input.is_empty());
assert_eq!(s, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa");
// Verify that we only consume up to fixstr length
assert_eq!(parse_msg_pack_str(&[0xa0, b'a']).unwrap().0, b"a");
assert_eq!(
parse_msg_pack_str(&[
0xbf, b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a',
b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a',
b'a', b'a', b'a', b'a', b'a', b'a', b'b'
])
.unwrap()
.0,
b"b"
);
}
#[test]
fn should_be_able_to_parse_str_8() {
// 32 byte str
let (input, s) = parse_msg_pack_str(&[
0xd9, 32, b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a',
b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a',
b'a', b'a', b'a', b'a', b'a', b'a',
])
.unwrap();
assert!(input.is_empty());
assert_eq!(s, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa");
// 2^8 - 1 (255) byte str
let test_str = "a".repeat(2usize.pow(8) - 1);
let mut input = vec![0xd9, 255];
input.extend_from_slice(test_str.as_bytes());
let (input, s) = parse_msg_pack_str(&input).unwrap();
assert!(input.is_empty());
assert_eq!(s, test_str);
// Verify that we only consume up to 2^8 - 1 length
let mut input = vec![0xd9, 255];
input.extend_from_slice(test_str.as_bytes());
input.extend_from_slice(b"hello");
let (input, s) = parse_msg_pack_str(&input).unwrap();
assert_eq!(input, b"hello");
assert_eq!(s, test_str);
}
#[test]
fn should_be_able_to_parse_str_16() {
// 2^8 byte str (256)
let test_str = "a".repeat(2usize.pow(8));
let mut input = vec![0xda, 1, 0];
input.extend_from_slice(test_str.as_bytes());
let (input, s) = parse_msg_pack_str(&input).unwrap();
assert!(input.is_empty());
assert_eq!(s, test_str);
// 2^16 - 1 (65535) byte str
let test_str = "a".repeat(2usize.pow(16) - 1);
let mut input = vec![0xda, 255, 255];
input.extend_from_slice(test_str.as_bytes());
let (input, s) = parse_msg_pack_str(&input).unwrap();
assert!(input.is_empty());
assert_eq!(s, test_str);
// Verify that we only consume up to 2^16 - 1 length
let mut input = vec![0xda, 255, 255];
input.extend_from_slice(test_str.as_bytes());
input.extend_from_slice(b"hello");
let (input, s) = parse_msg_pack_str(&input).unwrap();
assert_eq!(input, b"hello");
assert_eq!(s, test_str);
}
#[test]
fn should_be_able_to_parse_str_32() {
// 2^16 byte str
let test_str = "a".repeat(2usize.pow(16));
let mut input = vec![0xdb, 0, 1, 0, 0];
input.extend_from_slice(test_str.as_bytes());
let (input, s) = parse_msg_pack_str(&input).unwrap();
assert!(input.is_empty());
assert_eq!(s, test_str);
// NOTE: We are not going to run the below tests, not because they aren't valid but
// because this generates a 4GB str which takes 20+ seconds to run
// 2^32 - 1 byte str (4294967295 bytes)
/* let test_str = "a".repeat(2usize.pow(32) - 1);
let mut input = vec![0xdb, 255, 255, 255, 255];
input.extend_from_slice(test_str.as_bytes());
let (input, s) = parse_msg_pack_str(&input).unwrap();
assert!(input.is_empty());
assert_eq!(s, test_str); */
// Verify that we only consume up to 2^32 - 1 length
/* let mut input = vec![0xdb, 255, 255, 255, 255];
input.extend_from_slice(test_str.as_bytes());
input.extend_from_slice(b"hello");
let (input, s) = parse_msg_pack_str(&input).unwrap();
assert_eq!(input, b"hello");
assert_eq!(s, test_str); */
}
#[test]
fn should_fail_parsing_str_with_invalid_length() {
// Make sure that parse doesn't fail looking for bytes after str 8 len
assert_eq!(
parse_msg_pack_str(&[0xd9]),
Err(MsgPackStrParseError::InvalidFormat)
);
assert_eq!(
parse_msg_pack_str(&[0xd9, 0]),
Err(MsgPackStrParseError::InvalidFormat)
);
// Make sure that parse doesn't fail looking for bytes after str 16 len
assert_eq!(
parse_msg_pack_str(&[0xda]),
Err(MsgPackStrParseError::InvalidFormat)
);
assert_eq!(
parse_msg_pack_str(&[0xda, 0]),
Err(MsgPackStrParseError::InvalidFormat)
);
assert_eq!(
parse_msg_pack_str(&[0xda, 0, 0]),
Err(MsgPackStrParseError::InvalidFormat)
);
// Make sure that parse doesn't fail looking for bytes after str 32 len
assert_eq!(
parse_msg_pack_str(&[0xdb]),
Err(MsgPackStrParseError::InvalidFormat)
);
assert_eq!(
parse_msg_pack_str(&[0xdb, 0]),
Err(MsgPackStrParseError::InvalidFormat)
);
assert_eq!(
parse_msg_pack_str(&[0xdb, 0, 0]),
Err(MsgPackStrParseError::InvalidFormat)
);
assert_eq!(
parse_msg_pack_str(&[0xdb, 0, 0, 0]),
Err(MsgPackStrParseError::InvalidFormat)
);
assert_eq!(
parse_msg_pack_str(&[0xdb, 0, 0, 0, 0]),
Err(MsgPackStrParseError::InvalidFormat)
);
}
#[test]
fn should_fail_parsing_other_types() {
assert_eq!(
parse_msg_pack_str(&[0xc3]), // Boolean (true)
Err(MsgPackStrParseError::InvalidFormat)
);
}
#[test]
fn should_fail_if_empty_input() {
assert_eq!(
parse_msg_pack_str(&[]),
Err(MsgPackStrParseError::InvalidFormat)
);
}
#[test]
fn should_fail_if_str_is_not_utf8() {
assert!(matches!(
parse_msg_pack_str(&[0xa4, 0, 159, 146, 150]),
Err(MsgPackStrParseError::Utf8Error(_))
));
}
}
}

@ -1,5 +1,6 @@
use super::{parse_msg_pack_str, Id};
use crate::utils;
use super::{parse_msg_pack_str, write_str_msg_pack, Id};
use crate::common::utils;
use derive_more::{Display, Error};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{borrow::Cow, io, str};
@ -37,6 +38,14 @@ where
pub fn to_payload_vec(&self) -> io::Result<Vec<u8>> {
utils::serialize_to_vec(&self.payload)
}
/// Attempts to convert a typed request to an untyped request
pub fn to_untyped_request(&self) -> io::Result<UntypedRequest> {
Ok(UntypedRequest {
id: Cow::Borrowed(&self.id),
payload: Cow::Owned(self.to_payload_vec()?),
})
}
}
impl<T> Request<T>
@ -63,7 +72,7 @@ impl<T> From<T> for Request<T> {
}
/// Error encountered when attempting to parse bytes as an untyped request
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[derive(Copy, Clone, Debug, Display, Error, PartialEq, Eq, Hash)]
pub enum UntypedRequestParseError {
/// When the bytes do not represent a request
WrongType,
@ -119,6 +128,24 @@ impl<'a> UntypedRequest<'a> {
}
}
/// Updates the id of the request to the given `id`.
pub fn set_id(&mut self, id: impl Into<String>) {
self.id = Cow::Owned(id.into());
}
/// Allocates a new collection of bytes representing the request.
pub fn to_bytes(&self) -> Vec<u8> {
let mut bytes = vec![0x82];
write_str_msg_pack("id", &mut bytes);
write_str_msg_pack(&self.id, &mut bytes);
write_str_msg_pack("payload", &mut bytes);
bytes.extend_from_slice(&self.payload);
bytes
}
/// Parses a collection of bytes, returning a partial request if it can be potentially
/// represented as a [`Request`] depending on the payload, or the original bytes if it does not
/// represent a [`Request`]
@ -169,6 +196,7 @@ impl<'a> UntypedRequest<'a> {
#[cfg(test)]
mod tests {
use super::*;
use test_log::test;
const TRUE_BYTE: u8 = 0xc3;
const NEVER_USED_BYTE: u8 = 0xc1;
@ -182,6 +210,19 @@ mod tests {
/// fixstr of 4 bytes with str "test"
const TEST_STR_BYTES: &[u8] = &[0xa4, 0x74, 0x65, 0x73, 0x74];
#[test]
fn untyped_request_should_support_converting_to_bytes() {
let bytes = Request {
id: "some id".to_string(),
payload: true,
}
.to_vec()
.unwrap();
let untyped_request = UntypedRequest::from_slice(&bytes).unwrap();
assert_eq!(untyped_request.to_bytes(), bytes);
}
#[test]
fn untyped_request_should_support_parsing_from_request_bytes_with_valid_payload() {
let bytes = Request {

@ -1,5 +1,6 @@
use super::{parse_msg_pack_str, Id};
use crate::utils;
use super::{parse_msg_pack_str, write_str_msg_pack, Id};
use crate::common::utils;
use derive_more::{Display, Error};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{borrow::Cow, io};
@ -41,6 +42,15 @@ where
pub fn to_payload_vec(&self) -> io::Result<Vec<u8>> {
utils::serialize_to_vec(&self.payload)
}
/// Attempts to convert a typed response to an untyped response
pub fn to_untyped_response(&self) -> io::Result<UntypedResponse> {
Ok(UntypedResponse {
id: Cow::Borrowed(&self.id),
origin_id: Cow::Borrowed(&self.origin_id),
payload: Cow::Owned(self.to_payload_vec()?),
})
}
}
impl<T> Response<T>
@ -61,7 +71,7 @@ impl<T: schemars::JsonSchema> Response<T> {
}
/// Error encountered when attempting to parse bytes as an untyped response
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[derive(Copy, Clone, Debug, Display, Error, PartialEq, Eq, Hash)]
pub enum UntypedResponseParseError {
/// When the bytes do not represent a response
WrongType,
@ -88,7 +98,7 @@ pub struct UntypedResponse<'a> {
impl<'a> UntypedResponse<'a> {
/// Attempts to convert an untyped request to a typed request
pub fn to_typed_request<T: DeserializeOwned>(&self) -> io::Result<Response<T>> {
pub fn to_typed_response<T: DeserializeOwned>(&self) -> io::Result<Response<T>> {
Ok(Response {
id: self.id.to_string(),
origin_id: self.origin_id.to_string(),
@ -132,9 +142,35 @@ impl<'a> UntypedResponse<'a> {
}
}
/// Updates the id of the response to the given `id`.
pub fn set_id(&mut self, id: impl Into<String>) {
self.id = Cow::Owned(id.into());
}
/// Updates the origin id of the response to the given `origin_id`.
pub fn set_origin_id(&mut self, origin_id: impl Into<String>) {
self.origin_id = Cow::Owned(origin_id.into());
}
/// Allocates a new collection of bytes representing the response.
pub fn to_bytes(&self) -> Vec<u8> {
let mut bytes = vec![0x83];
write_str_msg_pack("id", &mut bytes);
write_str_msg_pack(&self.id, &mut bytes);
write_str_msg_pack("origin_id", &mut bytes);
write_str_msg_pack(&self.origin_id, &mut bytes);
write_str_msg_pack("payload", &mut bytes);
bytes.extend_from_slice(&self.payload);
bytes
}
/// Parses a collection of bytes, returning an untyped response if it can be potentially
/// represented as a [`Response`] depending on the payload, or the original bytes if it does not
/// represent a [`Response`]
/// represent a [`Response`].
///
/// NOTE: This supports parsing an invalid response where the payload would not properly
/// deserialize, but the bytes themselves represent a complete response of some kind.
@ -198,6 +234,7 @@ impl<'a> UntypedResponse<'a> {
#[cfg(test)]
mod tests {
use super::*;
use test_log::test;
const TRUE_BYTE: u8 = 0xc3;
const NEVER_USED_BYTE: u8 = 0xc1;
@ -215,6 +252,20 @@ mod tests {
/// fixstr of 4 bytes with str "test"
const TEST_STR_BYTES: &[u8] = &[0xa4, 0x74, 0x65, 0x73, 0x74];
#[test]
fn untyped_response_should_support_converting_to_bytes() {
let bytes = Response {
id: "some id".to_string(),
origin_id: "some origin id".to_string(),
payload: true,
}
.to_vec()
.unwrap();
let untyped_response = UntypedResponse::from_slice(&bytes).unwrap();
assert_eq!(untyped_response.to_bytes(), bytes);
}
#[test]
fn untyped_response_should_support_parsing_from_response_bytes_with_valid_payload() {
let bytes = Response {

@ -0,0 +1,629 @@
use async_trait::async_trait;
use std::{io, time::Duration};
mod framed;
pub use framed::*;
mod inmemory;
pub use inmemory::*;
mod tcp;
pub use tcp::*;
#[cfg(test)]
mod test;
#[cfg(test)]
pub use test::*;
#[cfg(unix)]
mod unix;
#[cfg(unix)]
pub use unix::*;
#[cfg(windows)]
mod windows;
#[cfg(windows)]
pub use windows::*;
pub use tokio::io::{Interest, Ready};
/// Duration to wait after WouldBlock received during looping operations like `read_exact`.
const SLEEP_DURATION: Duration = Duration::from_millis(1);
/// Interface representing a connection that is reconnectable.
#[async_trait]
pub trait Reconnectable {
/// Attempts to reconnect an already-established connection.
async fn reconnect(&mut self) -> io::Result<()>;
}
/// Interface representing a transport of raw bytes into and out of the system.
#[async_trait]
pub trait Transport: Reconnectable + Send + Sync {
/// Tries to read data from the transport into the provided buffer, returning how many bytes
/// were read.
///
/// This call may return an error with [`ErrorKind::WouldBlock`] in the case that the transport
/// is not ready to read data.
///
/// [`ErrorKind::WouldBlock`]: io::ErrorKind::WouldBlock
fn try_read(&self, buf: &mut [u8]) -> io::Result<usize>;
/// Try to write a buffer to the transport, returning how many bytes were written.
///
/// This call may return an error with [`ErrorKind::WouldBlock`] in the case that the transport
/// is not ready to write data.
///
/// [`ErrorKind::WouldBlock`]: io::ErrorKind::WouldBlock
fn try_write(&self, buf: &[u8]) -> io::Result<usize>;
/// Waits for the transport to be ready based on the given interest, returning the ready
/// status.
async fn ready(&self, interest: Interest) -> io::Result<Ready>;
}
#[async_trait]
impl Transport for Box<dyn Transport> {
fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
Transport::try_read(AsRef::as_ref(self), buf)
}
fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
Transport::try_write(AsRef::as_ref(self), buf)
}
async fn ready(&self, interest: Interest) -> io::Result<Ready> {
Transport::ready(AsRef::as_ref(self), interest).await
}
}
#[async_trait]
impl Reconnectable for Box<dyn Transport> {
async fn reconnect(&mut self) -> io::Result<()> {
Reconnectable::reconnect(AsMut::as_mut(self)).await
}
}
#[async_trait]
pub trait TransportExt {
/// Waits for the transport to be readable to follow up with `try_read`.
async fn readable(&self) -> io::Result<()>;
/// Waits for the transport to be writeable to follow up with `try_write`.
async fn writeable(&self) -> io::Result<()>;
/// Waits for the transport to be either readable or writeable.
async fn readable_or_writeable(&self) -> io::Result<()>;
/// Reads exactly `n` bytes where `n` is the length of `buf` by continuing to call [`try_read`]
/// until completed. Calls to [`readable`] are made to ensure the transport is ready. Returns
/// the total bytes read.
///
/// [`try_read`]: Transport::try_read
/// [`readable`]: Transport::readable
async fn read_exact(&self, buf: &mut [u8]) -> io::Result<usize>;
/// Reads all bytes until EOF in this source, placing them into `buf`.
///
/// All bytes read from this source will be appended to the specified buffer `buf`. This
/// function will continuously call [`try_read`] to append more data to `buf` until
/// [`try_read`] returns either [`Ok(0)`] or an error that is neither [`Interrupted`] or
/// [`WouldBlock`].
///
/// If successful, this function will return the total number of bytes read.
///
/// ### Errors
///
/// If this function encounters an error of the kind [`Interrupted`] or [`WouldBlock`], then
/// the error is ignored and the operation will continue.
///
/// If any other read error is encountered then this function immediately returns. Any bytes
/// which have already been read will be appended to `buf`.
///
/// [`Ok(0)`]: Ok
/// [`try_read`]: Transport::try_read
/// [`readable`]: Transport::readable
async fn read_to_end(&self, buf: &mut Vec<u8>) -> io::Result<usize>;
/// Reads all bytes until EOF in this source, placing them into `buf`.
///
/// If successful, this function will return the total number of bytes read.
///
/// ### Errors
///
/// If the data in this stream is *not* valid UTF-8 then an error is returned and `buf` is
/// unchanged.
///
/// See [`read_to_end`] for other error semantics.
///
/// [`Ok(0)`]: Ok
/// [`try_read`]: Transport::try_read
/// [`readable`]: Transport::readable
/// [`read_to_end`]: TransportExt::read_to_end
async fn read_to_string(&self, buf: &mut String) -> io::Result<usize>;
/// Writes all of `buf` by continuing to call [`try_write`] until completed. Calls to
/// [`writeable`] are made to ensure the transport is ready.
///
/// [`try_write`]: Transport::try_write
/// [`writable`]: Transport::writable
async fn write_all(&self, buf: &[u8]) -> io::Result<()>;
}
#[async_trait]
impl<T: Transport> TransportExt for T {
async fn readable(&self) -> io::Result<()> {
self.ready(Interest::READABLE).await?;
Ok(())
}
async fn writeable(&self) -> io::Result<()> {
self.ready(Interest::WRITABLE).await?;
Ok(())
}
async fn readable_or_writeable(&self) -> io::Result<()> {
self.ready(Interest::READABLE | Interest::WRITABLE).await?;
Ok(())
}
async fn read_exact(&self, buf: &mut [u8]) -> io::Result<usize> {
let mut i = 0;
while i < buf.len() {
self.readable().await?;
match self.try_read(&mut buf[i..]) {
// If we get 0 bytes read, this usually means that the underlying reader
// has closed, so we will return an EOF error to reflect that
//
// NOTE: `try_read` can also return 0 if the buf len is zero, but because we check
// that our index is < len, the situation where we call try_read with a buf
// of len 0 will never happen
Ok(0) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)),
Ok(n) => i += n,
// Because we are using `try_read`, it can be possible for it to return
// WouldBlock; so, if we encounter that then we just wait for next readable
Err(x) if x.kind() == io::ErrorKind::WouldBlock => {
// NOTE: We sleep for a little bit before trying again to avoid pegging CPU
tokio::time::sleep(SLEEP_DURATION).await
}
Err(x) => return Err(x),
}
}
Ok(i)
}
async fn read_to_end(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
let mut i = 0;
let mut tmp = [0u8; 1024];
loop {
self.readable().await?;
match self.try_read(&mut tmp) {
Ok(0) => return Ok(i),
Ok(n) => {
buf.extend_from_slice(&tmp[..n]);
i += n;
}
Err(x)
if x.kind() == io::ErrorKind::WouldBlock
|| x.kind() == io::ErrorKind::Interrupted =>
{
// NOTE: We sleep for a little bit before trying again to avoid pegging CPU
tokio::time::sleep(SLEEP_DURATION).await
}
Err(x) => return Err(x),
}
}
}
async fn read_to_string(&self, buf: &mut String) -> io::Result<usize> {
let mut tmp = Vec::new();
let n = self.read_to_end(&mut tmp).await?;
buf.push_str(
&String::from_utf8(tmp).map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))?,
);
Ok(n)
}
async fn write_all(&self, buf: &[u8]) -> io::Result<()> {
let mut i = 0;
while i < buf.len() {
self.writeable().await?;
match self.try_write(&buf[i..]) {
// If we get 0 bytes written, this usually means that the underlying writer
// has closed, so we will return a write zero error to reflect that
//
// NOTE: `try_write` can also return 0 if the buf len is zero, but because we check
// that our index is < len, the situation where we call try_write with a buf
// of len 0 will never happen
Ok(0) => return Err(io::Error::from(io::ErrorKind::WriteZero)),
Ok(n) => i += n,
// Because we are using `try_write`, it can be possible for it to return
// WouldBlock; so, if we encounter that then we just wait for next writeable
Err(x) if x.kind() == io::ErrorKind::WouldBlock => {
// NOTE: We sleep for a little bit before trying again to avoid pegging CPU
tokio::time::sleep(SLEEP_DURATION).await
}
Err(x) => return Err(x),
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use test_log::test;
#[test(tokio::test)]
async fn read_exact_should_fail_if_try_read_encounters_error_other_than_would_block() {
let transport = TestTransport {
f_try_read: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
f_ready: Box::new(|_| Ok(Ready::READABLE)),
..Default::default()
};
let mut buf = [0; 1];
assert_eq!(
transport.read_exact(&mut buf).await.unwrap_err().kind(),
io::ErrorKind::NotConnected
);
}
#[test(tokio::test)]
async fn read_exact_should_fail_if_try_read_returns_0_before_necessary_bytes_read() {
let transport = TestTransport {
f_try_read: Box::new(|_| Ok(0)),
f_ready: Box::new(|_| Ok(Ready::READABLE)),
..Default::default()
};
let mut buf = [0; 1];
assert_eq!(
transport.read_exact(&mut buf).await.unwrap_err().kind(),
io::ErrorKind::UnexpectedEof
);
}
#[test(tokio::test)]
async fn read_exact_should_continue_to_call_try_read_until_buffer_is_filled() {
let transport = TestTransport {
f_try_read: Box::new(|buf| {
static mut CNT: u8 = 0;
unsafe {
buf[0] = b'a' + CNT;
CNT += 1;
}
Ok(1)
}),
f_ready: Box::new(|_| Ok(Ready::READABLE)),
..Default::default()
};
let mut buf = [0; 3];
assert_eq!(transport.read_exact(&mut buf).await.unwrap(), 3);
assert_eq!(&buf, b"abc");
}
#[test(tokio::test)]
async fn read_exact_should_continue_to_call_try_read_while_it_returns_would_block() {
// Configure `try_read` to alternate between reading a byte and WouldBlock
let transport = TestTransport {
f_try_read: Box::new(|buf| {
static mut CNT: u8 = 0;
unsafe {
buf[0] = b'a' + CNT;
CNT += 1;
if CNT % 2 == 1 {
Ok(1)
} else {
Err(io::Error::from(io::ErrorKind::WouldBlock))
}
}
}),
f_ready: Box::new(|_| Ok(Ready::READABLE)),
..Default::default()
};
let mut buf = [0; 3];
assert_eq!(transport.read_exact(&mut buf).await.unwrap(), 3);
assert_eq!(&buf, b"ace");
}
#[test(tokio::test)]
async fn read_exact_should_return_0_if_given_a_buffer_of_0_len() {
let transport = TestTransport {
f_try_read: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
f_ready: Box::new(|_| Ok(Ready::READABLE)),
..Default::default()
};
let mut buf = [0; 0];
assert_eq!(transport.read_exact(&mut buf).await.unwrap(), 0);
}
#[test(tokio::test)]
async fn read_to_end_should_fail_if_try_read_encounters_error_other_than_would_block_and_interrupt(
) {
let transport = TestTransport {
f_try_read: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
f_ready: Box::new(|_| Ok(Ready::READABLE)),
..Default::default()
};
assert_eq!(
transport
.read_to_end(&mut Vec::new())
.await
.unwrap_err()
.kind(),
io::ErrorKind::NotConnected
);
}
#[test(tokio::test)]
async fn read_to_end_should_read_until_0_bytes_returned_from_try_read() {
let transport = TestTransport {
f_try_read: Box::new(|buf| {
static mut CNT: u8 = 0;
unsafe {
if CNT == 0 {
buf[..5].copy_from_slice(b"hello");
CNT += 1;
Ok(5)
} else {
Ok(0)
}
}
}),
f_ready: Box::new(|_| Ok(Ready::READABLE)),
..Default::default()
};
let mut buf = Vec::new();
assert_eq!(transport.read_to_end(&mut buf).await.unwrap(), 5);
assert_eq!(buf, b"hello");
}
#[test(tokio::test)]
async fn read_to_end_should_continue_reading_when_interrupt_or_would_block_encountered() {
let transport = TestTransport {
f_try_read: Box::new(|buf| {
static mut CNT: u8 = 0;
unsafe {
CNT += 1;
if CNT == 1 {
buf[..6].copy_from_slice(b"hello ");
Ok(6)
} else if CNT == 2 {
Err(io::Error::from(io::ErrorKind::WouldBlock))
} else if CNT == 3 {
buf[..5].copy_from_slice(b"world");
Ok(5)
} else if CNT == 4 {
Err(io::Error::from(io::ErrorKind::Interrupted))
} else if CNT == 5 {
buf[..6].copy_from_slice(b", test");
Ok(6)
} else {
Ok(0)
}
}
}),
f_ready: Box::new(|_| Ok(Ready::READABLE)),
..Default::default()
};
let mut buf = Vec::new();
assert_eq!(transport.read_to_end(&mut buf).await.unwrap(), 17);
assert_eq!(buf, b"hello world, test");
}
#[test(tokio::test)]
async fn read_to_string_should_fail_if_try_read_encounters_error_other_than_would_block_and_interrupt(
) {
let transport = TestTransport {
f_try_read: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
f_ready: Box::new(|_| Ok(Ready::READABLE)),
..Default::default()
};
assert_eq!(
transport
.read_to_string(&mut String::new())
.await
.unwrap_err()
.kind(),
io::ErrorKind::NotConnected
);
}
#[test(tokio::test)]
async fn read_to_string_should_fail_if_non_utf8_characters_read() {
let transport = TestTransport {
f_try_read: Box::new(|buf| {
static mut CNT: u8 = 0;
unsafe {
if CNT == 0 {
buf[0] = 0;
buf[1] = 159;
buf[2] = 146;
buf[3] = 150;
CNT += 1;
Ok(4)
} else {
Ok(0)
}
}
}),
f_ready: Box::new(|_| Ok(Ready::READABLE)),
..Default::default()
};
let mut buf = String::new();
assert_eq!(
transport.read_to_string(&mut buf).await.unwrap_err().kind(),
io::ErrorKind::InvalidData
);
}
#[test(tokio::test)]
async fn read_to_string_should_read_until_0_bytes_returned_from_try_read() {
let transport = TestTransport {
f_try_read: Box::new(|buf| {
static mut CNT: u8 = 0;
unsafe {
if CNT == 0 {
buf[..5].copy_from_slice(b"hello");
CNT += 1;
Ok(5)
} else {
Ok(0)
}
}
}),
f_ready: Box::new(|_| Ok(Ready::READABLE)),
..Default::default()
};
let mut buf = String::new();
assert_eq!(transport.read_to_string(&mut buf).await.unwrap(), 5);
assert_eq!(buf, "hello");
}
#[test(tokio::test)]
async fn read_to_string_should_continue_reading_when_interrupt_or_would_block_encountered() {
let transport = TestTransport {
f_try_read: Box::new(|buf| {
static mut CNT: u8 = 0;
unsafe {
CNT += 1;
if CNT == 1 {
buf[..6].copy_from_slice(b"hello ");
Ok(6)
} else if CNT == 2 {
Err(io::Error::from(io::ErrorKind::WouldBlock))
} else if CNT == 3 {
buf[..5].copy_from_slice(b"world");
Ok(5)
} else if CNT == 4 {
Err(io::Error::from(io::ErrorKind::Interrupted))
} else if CNT == 5 {
buf[..6].copy_from_slice(b", test");
Ok(6)
} else {
Ok(0)
}
}
}),
f_ready: Box::new(|_| Ok(Ready::READABLE)),
..Default::default()
};
let mut buf = String::new();
assert_eq!(transport.read_to_string(&mut buf).await.unwrap(), 17);
assert_eq!(buf, "hello world, test");
}
#[test(tokio::test)]
async fn write_all_should_fail_if_try_write_encounters_error_other_than_would_block() {
let transport = TestTransport {
f_try_write: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
..Default::default()
};
assert_eq!(
transport.write_all(b"abc").await.unwrap_err().kind(),
io::ErrorKind::NotConnected
);
}
#[test(tokio::test)]
async fn write_all_should_fail_if_try_write_returns_0_before_all_bytes_written() {
let transport = TestTransport {
f_try_write: Box::new(|_| Ok(0)),
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
..Default::default()
};
assert_eq!(
transport.write_all(b"abc").await.unwrap_err().kind(),
io::ErrorKind::WriteZero
);
}
#[test(tokio::test)]
async fn write_all_should_continue_to_call_try_write_until_all_bytes_written() {
// Configure `try_write` to alternate between writing a byte and WouldBlock
let transport = TestTransport {
f_try_write: Box::new(|buf| {
static mut CNT: u8 = 0;
unsafe {
assert_eq!(buf[0], b'a' + CNT);
CNT += 1;
Ok(1)
}
}),
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
..Default::default()
};
transport.write_all(b"abc").await.unwrap();
}
#[test(tokio::test)]
async fn write_all_should_continue_to_call_try_write_while_it_returns_would_block() {
// Configure `try_write` to alternate between writing a byte and WouldBlock
let transport = TestTransport {
f_try_write: Box::new(|buf| {
static mut CNT: u8 = 0;
unsafe {
if CNT % 2 == 0 {
assert_eq!(buf[0], b'a' + CNT);
CNT += 1;
Ok(1)
} else {
CNT += 1;
Err(io::Error::from(io::ErrorKind::WouldBlock))
}
}
}),
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
..Default::default()
};
transport.write_all(b"ace").await.unwrap();
}
#[test(tokio::test)]
async fn write_all_should_return_immediately_if_given_buffer_of_0_len() {
let transport = TestTransport {
f_try_write: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
..Default::default()
};
// No error takes place as we never call try_write
let buf = [0; 0];
transport.write_all(&buf).await.unwrap();
}
}

File diff suppressed because it is too large Load Diff

@ -0,0 +1,201 @@
use super::{Frame, OwnedFrame};
use std::collections::VecDeque;
/// Maximum size (in bytes) for saved frames (256MiB)
const MAX_BACKUP_SIZE: usize = 256 * 1024 * 1024;
/// Stores [`Frame`]s for reuse later.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Backup {
/// Maximum size (in bytes) to save frames in case we need to backup them
///
/// NOTE: If 0, no frames will be stored.
max_backup_size: usize,
/// Tracker for the total size (in bytes) of stored frames
current_backup_size: usize,
/// Storage used to hold outgoing frames in case they need to be reused
frames: VecDeque<OwnedFrame>,
/// Counter keeping track of total frames sent
sent_cnt: u64,
/// Counter keeping track of total frames received
received_cnt: u64,
/// Indicates whether the backup is frozen, which indicates that mutations are ignored
frozen: bool,
}
impl Default for Backup {
fn default() -> Self {
Self::new()
}
}
impl Backup {
/// Creates a new, unfrozen backup.
pub fn new() -> Self {
Self {
max_backup_size: MAX_BACKUP_SIZE,
current_backup_size: 0,
frames: VecDeque::new(),
sent_cnt: 0,
received_cnt: 0,
frozen: false,
}
}
/// Clears the backup of any stored data and resets the state to being new.
///
/// ### Note
///
/// Like all other modifications, this will do nothing if the backup is frozen.
pub fn clear(&mut self) {
if !self.frozen {
self.current_backup_size = 0;
self.frames.clear();
self.sent_cnt = 0;
self.received_cnt = 0;
}
}
/// Returns true if the backup is frozen, meaning that modifications will be ignored.
#[inline]
pub fn is_frozen(&self) -> bool {
self.frozen
}
/// Sets the frozen status.
#[inline]
pub fn set_frozen(&mut self, frozen: bool) {
self.frozen = frozen;
}
/// Marks the backup as frozen.
#[inline]
pub fn freeze(&mut self) {
self.frozen = true;
}
/// Marks the backup as no longer frozen.
#[inline]
pub fn unfreeze(&mut self) {
self.frozen = false;
}
/// Sets the maximum size (in bytes) of collective frames stored in case a backup is needed
/// during reconnection. Setting the `size` to 0 will result in no frames being stored.
///
/// ### Note
///
/// Like all other modifications, this will do nothing if the backup is frozen.
pub fn set_max_backup_size(&mut self, size: usize) {
if !self.frozen {
self.max_backup_size = size;
}
}
/// Returns the maximum size (in bytes) of collective frames stored in case a backup is needed
/// during reconnection.
pub fn max_backup_size(&self) -> usize {
self.max_backup_size
}
/// Increments (by 1) the total sent frames.
///
/// ### Note
///
/// Like all other modifications, this will do nothing if the backup is frozen.
pub(crate) fn increment_sent_cnt(&mut self) {
if !self.frozen {
self.sent_cnt += 1;
}
}
/// Returns how many frames have been sent.
pub(crate) fn sent_cnt(&self) -> u64 {
self.sent_cnt
}
/// Increments (by 1) the total received frames.
///
/// ### Note
///
/// Like all other modifications, this will do nothing if the backup is frozen.
pub(super) fn increment_received_cnt(&mut self) {
if !self.frozen {
self.received_cnt += 1;
}
}
/// Returns how many frames have been received.
pub(crate) fn received_cnt(&self) -> u64 {
self.received_cnt
}
/// Sets the total received frames to the specified `cnt`.
///
/// ### Note
///
/// Like all other modifications, this will do nothing if the backup is frozen.
pub(super) fn set_received_cnt(&mut self, cnt: u64) {
if !self.frozen {
self.received_cnt = cnt;
}
}
/// Pushes a new frame to the end of the internal queue.
///
/// ### Note
///
/// Like all other modifications, this will do nothing if the backup is frozen.
pub(crate) fn push_frame(&mut self, frame: Frame) {
if self.max_backup_size > 0 && !self.frozen {
self.current_backup_size += frame.len();
self.frames.push_back(frame.into_owned());
while self.current_backup_size > self.max_backup_size {
match self.frames.pop_front() {
Some(frame) => {
self.current_backup_size -= frame.len();
}
// If we have exhausted all frames, then we have reached
// an internal size of 0 and should exit the loop
None => {
self.current_backup_size = 0;
break;
}
}
}
}
}
/// Returns the total frames being kept for potential reuse.
pub(super) fn frame_cnt(&self) -> usize {
self.frames.len()
}
/// Returns an iterator over the frames contained in the backup.
pub(super) fn frames(&self) -> impl Iterator<Item = &Frame> {
self.frames.iter()
}
/// Truncates the stored frames to be no larger than `size` total frames by popping from the
/// front rather than the back of the list.
///
/// ### Note
///
/// Like all other modifications, this will do nothing if the backup is frozen.
pub(super) fn truncate_front(&mut self, size: usize) {
if !self.frozen {
while self.frames.len() > size {
if let Some(frame) = self.frames.pop_front() {
self.current_backup_size -=
std::cmp::min(frame.len(), self.current_backup_size);
}
}
}
}
}

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save