Unfinished refactoring to support configuring versions for connecting -- need to find all places where we build client/server and set version to protocol OR net package version for manager

pull/219/head
Chip Senkbeil 10 months ago
parent d5b41916cb
commit cf2c5c700e
No known key found for this signature in database
GPG Key ID: 35EF1F8EC72A4131

@ -7,8 +7,9 @@ use distant_core::{
};
use distant_net::auth::{DummyAuthHandler, Verifier};
use distant_net::client::Client;
use distant_net::common::{InmemoryTransport, OneshotListener};
use distant_net::common::{InmemoryTransport, OneshotListener, Version};
use distant_net::server::{Server, ServerRef};
use distant_protocol::PROTOCOL_VERSION;
/// Stands up an inmemory client and server using the given api.
async fn setup(api: impl DistantApi + Send + Sync + 'static) -> (DistantClient, ServerRef) {
@ -17,12 +18,22 @@ async fn setup(api: impl DistantApi + Send + Sync + 'static) -> (DistantClient,
let server = Server::new()
.handler(DistantApiServerHandler::new(api))
.verifier(Verifier::none())
.version(Version::new(
PROTOCOL_VERSION.major,
PROTOCOL_VERSION.minor,
PROTOCOL_VERSION.patch,
))
.start(OneshotListener::from_value(t2))
.expect("Failed to start server");
let client: DistantClient = Client::build()
.auth_handler(DummyAuthHandler)
.connector(t1)
.version(Version::new(
PROTOCOL_VERSION.major,
PROTOCOL_VERSION.minor,
PROTOCOL_VERSION.patch,
))
.connect()
.await
.expect("Failed to connect to server");

@ -20,7 +20,7 @@ pub use windows::*;
use super::ClientConfig;
use crate::client::{Client, UntypedClient};
use crate::common::{Connection, Transport};
use crate::common::{Connection, Transport, Version};
/// Interface that performs the connection to produce a [`Transport`] for use by the [`Client`].
#[async_trait]
@ -46,6 +46,7 @@ pub struct ClientBuilder<H, C> {
connector: C,
config: ClientConfig,
connect_timeout: Option<Duration>,
version: Version,
}
impl<H, C> ClientBuilder<H, C> {
@ -56,6 +57,7 @@ impl<H, C> ClientBuilder<H, C> {
config: self.config,
connector: self.connector,
connect_timeout: self.connect_timeout,
version: self.version,
}
}
@ -66,6 +68,7 @@ impl<H, C> ClientBuilder<H, C> {
config,
connector: self.connector,
connect_timeout: self.connect_timeout,
version: self.version,
}
}
@ -76,6 +79,7 @@ impl<H, C> ClientBuilder<H, C> {
config: self.config,
connector,
connect_timeout: self.connect_timeout,
version: self.version,
}
}
@ -86,6 +90,18 @@ impl<H, C> ClientBuilder<H, C> {
config: self.config,
connector: self.connector,
connect_timeout: connect_timeout.into(),
version: self.version,
}
}
/// Configure the version of the client.
pub fn version(self, version: Version) -> Self {
Self {
auth_handler: self.auth_handler,
config: self.config,
connector: self.connector,
connect_timeout: self.connect_timeout,
version,
}
}
}
@ -97,6 +113,7 @@ impl ClientBuilder<(), ()> {
config: Default::default(),
connector: (),
connect_timeout: None,
version: Default::default(),
}
}
}
@ -119,6 +136,7 @@ where
let auth_handler = self.auth_handler;
let config = self.config;
let connect_timeout = self.connect_timeout;
let version = self.version;
let f = async move {
let transport = match connect_timeout {
@ -128,7 +146,7 @@ where
.and_then(convert::identity)?,
None => self.connector.connect().await?,
};
let connection = Connection::client(transport, auth_handler).await?;
let connection = Connection::client(transport, auth_handler, version).await?;
Ok(UntypedClient::spawn(connection, config))
};

@ -9,6 +9,7 @@ mod packet;
mod port;
mod transport;
pub(crate) mod utils;
mod version;
pub use any::*;
pub(crate) use connection::Connection;
@ -21,3 +22,4 @@ pub use map::*;
pub use packet::*;
pub use port::*;
pub use transport::*;
pub use version::*;

@ -11,6 +11,7 @@ use tokio::sync::oneshot;
use crate::common::InmemoryTransport;
use crate::common::{
Backup, FramedTransport, HeapSecretKey, Keychain, KeychainResult, Reconnectable, Transport,
TransportExt, Version,
};
/// Id of the connection
@ -110,6 +111,19 @@ where
debug!("[Conn {id}] Re-establishing connection");
Reconnectable::reconnect(transport).await?;
// Wait for exactly version bytes (24 where 8 bytes for major, minor, patch)
// but with a reconnect we don't actually validate it because we did that
// the first time we connected
//
// NOTE: We do this with the raw transport and not the framed version!
debug!("[Conn {id}] Waiting for server version");
if transport.as_mut_inner().read_exact(&mut [0u8; 24]).await? != 24 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Wrong version byte len received",
));
}
// Perform a handshake to ensure that the connection is properly established and encrypted
debug!("[Conn {id}] Performing handshake");
transport.client_handshake().await?;
@ -190,13 +204,42 @@ where
/// Transforms a raw [`Transport`] into an established [`Connection`] from the client-side by
/// performing the following:
///
/// 1. Handshakes to derive the appropriate [`Codec`](crate::Codec) to use
/// 2. Authenticates the established connection to ensure it is valid
/// 3. Restores pre-existing state using the provided backup, replaying any missing frames and
/// 1. Performs a version check with the server
/// 2. Handshakes to derive the appropriate [`Codec`](crate::Codec) to use
/// 3. Authenticates the established connection to ensure it is valid
/// 4. Restores pre-existing state using the provided backup, replaying any missing frames and
/// receiving any frames from the other side
pub async fn client<H: AuthHandler + Send>(transport: T, handler: H) -> io::Result<Self> {
pub async fn client<H: AuthHandler + Send>(
transport: T,
handler: H,
version: Version,
) -> io::Result<Self> {
let id: ConnectionId = rand::random();
// Wait for exactly version bytes (24 where 8 bytes for major, minor, patch)
debug!("[Conn {id}] Waiting for server version");
let mut version_bytes = [0u8; 24];
if transport.read_exact(&mut version_bytes).await? != 24 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Wrong version byte len received",
));
}
// Compare versions for compatibility and drop the connection if incompatible
let server_version = Version::from_be_bytes(version_bytes);
debug!(
"[Conn {id}] Checking compatibility between client {version} & server {server_version}"
);
if !version.is_compatible_with(&server_version) {
return Err(io::Error::new(
io::ErrorKind::Other,
format!(
"Client version {version} is incompatible with server version {server_version}"
),
));
}
// Perform a handshake to ensure that the connection is properly established and encrypted
debug!("[Conn {id}] Performing handshake");
let mut transport: FramedTransport<T> =
@ -238,19 +281,25 @@ where
/// Transforms a raw [`Transport`] into an established [`Connection`] from the server-side by
/// performing the following:
///
/// 1. Handshakes to derive the appropriate [`Codec`](crate::Codec) to use
/// 2. Authenticates the established connection to ensure it is valid by either using the
/// 1. Performs a version check with the client
/// 2. Handshakes to derive the appropriate [`Codec`](crate::Codec) to use
/// 3. Authenticates the established connection to ensure it is valid by either using the
/// given `verifier` or, if working with an existing client connection, will validate an OTP
/// from our database
/// 3. Restores pre-existing state using the provided backup, replaying any missing frames and
/// 4. Restores pre-existing state using the provided backup, replaying any missing frames and
/// receiving any frames from the other side
pub async fn server(
transport: T,
verifier: &Verifier,
keychain: Keychain<oneshot::Receiver<Backup>>,
version: Version,
) -> io::Result<Self> {
let id: ConnectionId = rand::random();
// Write the version as bytes
debug!("[Conn {id}] Sending version {version}");
transport.write_all(&version.to_be_bytes()).await?;
// Perform a handshake to ensure that the connection is properly established and encrypted
debug!("[Conn {id}] Performing handshake");
let mut transport: FramedTransport<T> =
@ -464,6 +513,60 @@ mod tests {
use super::*;
use crate::common::Frame;
macro_rules! server_version {
() => {
Version::new(1, 2, 3)
};
}
macro_rules! send_server_version {
($transport:expr, $version:expr) => {{
($transport)
.as_mut_inner()
.write_all(&$version.to_be_bytes())
.await
.unwrap();
}};
($transport:expr) => {
send_server_version!($transport, server_version!());
};
}
macro_rules! receive_version {
($transport:expr) => {{
let mut bytes = [0u8; 24];
assert_eq!(
($transport)
.as_mut_inner()
.read_exact(&mut bytes)
.await
.unwrap(),
24,
"Wrong version len received"
);
Version::from_be_bytes(bytes)
}};
}
#[test(tokio::test)]
async fn client_should_fail_when_server_sends_incompatible_version() {
let (mut t1, t2) = FramedTransport::pair(100);
// Spawn a task to perform the client connection so we don't deadlock while simulating the
// server actions on the other side
let task = tokio::spawn(async move {
Connection::client(t2.into_inner(), DummyAuthHandler, Version::new(1, 2, 3))
.await
.unwrap()
});
// Send invalid version to fail the handshake
send_server_version!(t1, Version::new(2, 0, 0));
// Client should fail
task.await.unwrap_err();
}
#[test(tokio::test)]
async fn client_should_fail_if_codec_handshake_fails() {
let (mut t1, t2) = FramedTransport::pair(100);
@ -471,11 +574,14 @@ mod tests {
// Spawn a task to perform the client connection so we don't deadlock while simulating the
// server actions on the other side
let task = tokio::spawn(async move {
Connection::client(t2.into_inner(), DummyAuthHandler)
Connection::client(t2.into_inner(), DummyAuthHandler, server_version!())
.await
.unwrap()
});
// Send server version for client to confirm
send_server_version!(t1);
// Send garbage to fail the handshake
t1.write_frame(Frame::new(b"invalid")).await.unwrap();
@ -490,11 +596,14 @@ mod tests {
// Spawn a task to perform the client connection so we don't deadlock while simulating the
// server actions on the other side
let task = tokio::spawn(async move {
Connection::client(t2.into_inner(), DummyAuthHandler)
Connection::client(t2.into_inner(), DummyAuthHandler, server_version!())
.await
.unwrap()
});
// Send server version for client to confirm
send_server_version!(t1);
// Perform first step of connection by establishing the codec
t1.server_handshake().await.unwrap();
@ -519,11 +628,14 @@ mod tests {
// Spawn a task to perform the client connection so we don't deadlock while simulating the
// server actions on the other side
let task = tokio::spawn(async move {
Connection::client(t2.into_inner(), DummyAuthHandler)
Connection::client(t2.into_inner(), DummyAuthHandler, server_version!())
.await
.unwrap()
});
// Send server version for client to confirm
send_server_version!(t1);
// Perform first step of connection by establishing the codec
t1.server_handshake().await.unwrap();
@ -559,11 +671,14 @@ mod tests {
// Spawn a task to perform the client connection so we don't deadlock while simulating the
// server actions on the other side
let task = tokio::spawn(async move {
Connection::client(t2.into_inner(), DummyAuthHandler)
Connection::client(t2.into_inner(), DummyAuthHandler, server_version!())
.await
.unwrap()
});
// Send server version for client to confirm
send_server_version!(t1);
// Perform first step of connection by establishing the codec
t1.server_handshake().await.unwrap();
@ -597,11 +712,14 @@ mod tests {
// Spawn a task to perform the client connection so we don't deadlock while simulating the
// server actions on the other side
let task = tokio::spawn(async move {
Connection::client(t2.into_inner(), DummyAuthHandler)
Connection::client(t2.into_inner(), DummyAuthHandler, server_version!())
.await
.unwrap()
});
// Send server version for client to confirm
send_server_version!(t1);
// Perform first step of connection by establishing the codec
t1.server_handshake().await.unwrap();
@ -629,6 +747,30 @@ mod tests {
assert_eq!(client.otp(), Some(&otp));
}
#[test(tokio::test)]
async fn server_should_fail_if_client_drops_due_to_version() {
let (mut t1, t2) = FramedTransport::pair(100);
let verifier = Verifier::none();
let keychain = Keychain::new();
// Spawn a task to perform the server connection so we don't deadlock while simulating the
// client actions on the other side
let task = tokio::spawn(async move {
Connection::server(t2.into_inner(), &verifier, keychain, server_version!())
.await
.unwrap()
});
// Receive the version from the server
let _ = receive_version!(t1);
// Drop client connection as a result of an "incompatible version"
drop(t1);
// Server should fail
task.await.unwrap_err();
}
#[test(tokio::test)]
async fn server_should_fail_if_codec_handshake_fails() {
let (mut t1, t2) = FramedTransport::pair(100);
@ -638,11 +780,14 @@ mod tests {
// Spawn a task to perform the server connection so we don't deadlock while simulating the
// client actions on the other side
let task = tokio::spawn(async move {
Connection::server(t2.into_inner(), &verifier, keychain)
Connection::server(t2.into_inner(), &verifier, keychain, server_version!())
.await
.unwrap()
});
// Receive the version from the server
let _ = receive_version!(t1);
// Send garbage to fail the handshake
t1.write_frame(Frame::new(b"invalid")).await.unwrap();
@ -659,11 +804,14 @@ mod tests {
// Spawn a task to perform the server connection so we don't deadlock while simulating the
// client actions on the other side
let task = tokio::spawn(async move {
Connection::server(t2.into_inner(), &verifier, keychain)
Connection::server(t2.into_inner(), &verifier, keychain, server_version!())
.await
.unwrap()
});
// Receive the version from the server
let _ = receive_version!(t1);
// Perform first step of completing client-side of handshake
t1.client_handshake().await.unwrap();
@ -683,11 +831,14 @@ mod tests {
// Spawn a task to perform the server connection so we don't deadlock while simulating the
// client actions on the other side
let task = tokio::spawn(async move {
Connection::server(t2.into_inner(), &verifier, keychain)
Connection::server(t2.into_inner(), &verifier, keychain, server_version!())
.await
.unwrap()
});
// Receive the version from the server
let _ = receive_version!(t1);
// Perform first step of completing client-side of handshake
t1.client_handshake().await.unwrap();
@ -717,11 +868,14 @@ mod tests {
// Spawn a task to perform the server connection so we don't deadlock while simulating the
// client actions on the other side
let task = tokio::spawn(async move {
Connection::server(t2.into_inner(), &verifier, keychain)
Connection::server(t2.into_inner(), &verifier, keychain, server_version!())
.await
.unwrap()
});
// Receive the version from the server
let _ = receive_version!(t1);
// Perform first step of completing client-side of handshake
t1.client_handshake().await.unwrap();
@ -750,11 +904,14 @@ mod tests {
// Spawn a task to perform the server connection so we don't deadlock while simulating the
// client actions on the other side
let task = tokio::spawn(async move {
Connection::server(t2.into_inner(), &verifier, keychain)
Connection::server(t2.into_inner(), &verifier, keychain, server_version!())
.await
.unwrap()
});
// Receive the version from the server
let _ = receive_version!(t1);
// Perform first step of completing client-side of handshake
t1.client_handshake().await.unwrap();
@ -790,11 +947,14 @@ mod tests {
// Spawn a task to perform the server connection so we don't deadlock while simulating the
// client actions on the other side
let task = tokio::spawn(async move {
Connection::server(t2.into_inner(), &verifier, keychain)
Connection::server(t2.into_inner(), &verifier, keychain, server_version!())
.await
.unwrap()
});
// Receive the version from the server
let _ = receive_version!(t1);
// Perform first step of completing client-side of handshake
t1.client_handshake().await.unwrap();
@ -828,11 +988,14 @@ mod tests {
// Spawn a task to perform the server connection so we don't deadlock while simulating the
// client actions on the other side
let task = tokio::spawn(async move {
Connection::server(t2.into_inner(), &verifier, keychain)
Connection::server(t2.into_inner(), &verifier, keychain, server_version!())
.await
.unwrap()
});
// Receive the version from the server
let _ = receive_version!(t1);
// Perform first step of completing client-side of handshake
t1.client_handshake().await.unwrap();
@ -866,11 +1029,14 @@ mod tests {
// Spawn a task to perform the server connection so we don't deadlock while simulating the
// client actions on the other side
let task = tokio::spawn(async move {
Connection::server(t2.into_inner(), &verifier, keychain)
Connection::server(t2.into_inner(), &verifier, keychain, server_version!())
.await
.unwrap()
});
// Receive the version from the server
let _ = receive_version!(t1);
// Perform first step of completing client-side of handshake
t1.client_handshake().await.unwrap();
@ -904,12 +1070,15 @@ mod tests {
let task = tokio::spawn({
let keychain = keychain.clone();
async move {
Connection::server(t2.into_inner(), &verifier, keychain)
Connection::server(t2.into_inner(), &verifier, keychain, server_version!())
.await
.unwrap()
}
});
// Receive the version from the server
let _ = receive_version!(t1);
// Perform first step of completing client-side of handshake
t1.client_handshake().await.unwrap();
@ -969,12 +1138,15 @@ mod tests {
let task = tokio::spawn({
let keychain = keychain.clone();
async move {
Connection::server(t2.into_inner(), &verifier, keychain)
Connection::server(t2.into_inner(), &verifier, keychain, server_version!())
.await
.unwrap()
}
});
// Receive the version from the server
let _ = receive_version!(t1);
// Perform first step of completing client-side of handshake
t1.client_handshake().await.unwrap();
@ -1029,13 +1201,13 @@ mod tests {
// Spawn a task to perform the server connection so we don't deadlock
let task = tokio::spawn(async move {
Connection::server(t2, &verifier, keychain)
Connection::server(t2, &verifier, keychain, server_version!())
.await
.expect("Failed to connect from server")
});
// Perform the client-side of the connection
let mut client = Connection::client(t1, DummyAuthHandler)
let mut client = Connection::client(t1, DummyAuthHandler, server_version!())
.await
.expect("Failed to connect from client");
let mut server = task.await.unwrap();
@ -1063,14 +1235,14 @@ mod tests {
let verifier = Arc::clone(&verifier);
let keychain = keychain.clone();
tokio::spawn(async move {
Connection::server(t2, &verifier, keychain)
Connection::server(t2, &verifier, keychain, server_version!())
.await
.expect("Failed to connect from server")
})
};
// Perform the client-side of the connection
let mut client = Connection::client(t1, DummyAuthHandler)
let mut client = Connection::client(t1, DummyAuthHandler, server_version!())
.await
.expect("Failed to connect from client");
@ -1093,6 +1265,9 @@ mod tests {
// Spawn a task to perform the client reconnection so we don't deadlock
let task = tokio::spawn(async move { client.reconnect().await.unwrap() });
// Send a version, although it'll be ignored by a reconnecting client
send_server_version!(transport);
// Send garbage to fail handshake from server-side
transport.write_frame(b"hello").await.unwrap();
@ -1108,6 +1283,9 @@ mod tests {
// Spawn a task to perform the client reconnection so we don't deadlock
let task = tokio::spawn(async move { client.reconnect().await.unwrap() });
// Send a version, although it'll be ignored by a reconnecting client
send_server_version!(transport);
// Perform first step of completing server-side of handshake
transport.server_handshake().await.unwrap();
@ -1126,6 +1304,9 @@ mod tests {
// Spawn a task to perform the client reconnection so we don't deadlock
let task = tokio::spawn(async move { client.reconnect().await.unwrap() });
// Send a version, although it'll be ignored by a reconnecting client
send_server_version!(transport);
// Perform first step of completing server-side of handshake
transport.server_handshake().await.unwrap();
@ -1162,6 +1343,9 @@ mod tests {
// Spawn a task to perform the client reconnection so we don't deadlock
let task = tokio::spawn(async move { client.reconnect().await.unwrap() });
// Send a version, although it'll be ignored by a reconnecting client
send_server_version!(transport);
// Perform first step of completing server-side of handshake
transport.server_handshake().await.unwrap();
@ -1205,6 +1389,9 @@ mod tests {
client
});
// Send a version, although it'll be ignored by a reconnecting client
send_server_version!(transport);
// Perform first step of completing server-side of handshake
transport.server_handshake().await.unwrap();
@ -1275,7 +1462,7 @@ mod tests {
// Spawn a task to perform the server reconnection so we don't deadlock
let task = tokio::spawn(async move {
Connection::server(transport, &verifier, keychain)
Connection::server(transport, &verifier, keychain, server_version!())
.await
.expect("Failed to connect from server")
});

@ -0,0 +1,119 @@
use semver::{Comparator, Op, Prerelease, Version as SemVer, VersionReq};
use std::fmt;
/// Represents a version and compatibility rules.
#[derive(Clone, Debug)]
pub struct Version {
inner: SemVer,
rules: VersionReq,
}
impl Version {
/// Creates a new version in the form `major.minor.patch` with a ruleset that is used to check
/// other versions such that `>=0.1.2, <0.2.0` or `>=1.2.3, <2` depending on whether or not the
/// major version is `0`.
///
/// ```
/// use distant_net::common::Version;
///
/// // Matching versions are compatible
/// let a = Version::new(1, 2, 3);
/// let b = Version::new(1, 2, 3);
/// assert!(a.is_compatible_with(&b));
///
/// // Version 1.2.3 is compatible with 1.2.4, but not the other way
/// let a = Version::new(1, 2, 3);
/// let b = Version::new(1, 2, 4);
/// assert!(a.is_compatible_with(&b));
/// assert!(!b.is_compatible_with(&a));
///
/// // Version 1.2.3 is compatible with 1.3.0, but not 2
/// let a = Version::new(1, 2, 3);
/// assert!(a.is_compatible_with(&Version::new(1, 3, 0)));
/// assert!(!a.is_compatible_with(&Version::new(2, 0, 0)));
///
/// // Version 0.1.2 is compatible with 0.1.3, but not the other way
/// let a = Version::new(0, 1, 2);
/// let b = Version::new(0, 1, 3);
/// assert!(a.is_compatible_with(&b));
/// assert!(!b.is_compatible_with(&a));
///
/// // Version 0.1.2 is not compatible with 0.2
/// let a = Version::new(0, 1, 2);
/// let b = Version::new(0, 2, 0);
/// assert!(!a.is_compatible_with(&b));
/// assert!(!b.is_compatible_with(&a));
/// ```
pub fn new(major: u64, minor: u64, patch: u64) -> Self {
Self {
inner: SemVer::new(major, minor, patch),
rules: VersionReq {
comparators: vec![
Comparator {
op: Op::GreaterEq,
major,
minor: Some(minor),
patch: Some(patch),
pre: Prerelease::EMPTY,
},
Comparator {
op: Op::Less,
major: if major == 0 { 0 } else { major + 1 },
minor: if major == 0 { Some(minor + 1) } else { None },
patch: None,
pre: Prerelease::EMPTY,
},
],
},
}
}
/// Returns true if this version is compatible with another version.
pub fn is_compatible_with(&self, other: &Self) -> bool {
self.rules.matches(&other.inner)
}
/// Converts from a collection of bytes into a version using the byte form major/minor/patch
/// using big endian.
pub fn from_be_bytes(bytes: [u8; 24]) -> Self {
Self::new(
u64::from_be_bytes([
bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
]),
u64::from_be_bytes([
bytes[8], bytes[9], bytes[10], bytes[11], bytes[12], bytes[13], bytes[14],
bytes[15],
]),
u64::from_be_bytes([
bytes[16], bytes[17], bytes[18], bytes[19], bytes[20], bytes[21], bytes[22],
bytes[23],
]),
)
}
/// Converts the version into a byte form of major/minor/patch using big endian.
pub const fn to_be_bytes(&self) -> [u8; 24] {
let major = self.inner.major.to_be_bytes();
let minor = self.inner.minor.to_be_bytes();
let patch = self.inner.patch.to_be_bytes();
[
major[0], major[1], major[2], major[3], major[4], major[5], major[6], major[7],
minor[0], minor[1], minor[2], minor[3], minor[4], minor[5], minor[6], minor[7],
patch[0], patch[1], patch[2], patch[3], patch[4], patch[5], patch[6], patch[7],
]
}
}
impl Default for Version {
/// Default version is `0.0.0`.
fn default() -> Self {
Self::new(0, 0, 0)
}
}
impl fmt::Display for Version {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.inner)
}
}

@ -9,7 +9,7 @@ use serde::de::DeserializeOwned;
use serde::Serialize;
use tokio::sync::{broadcast, RwLock};
use crate::common::{ConnectionId, Listener, Response, Transport};
use crate::common::{ConnectionId, Listener, Response, Transport, Version};
mod builder;
pub use builder::*;
@ -45,6 +45,9 @@ pub struct Server<T> {
/// Performs authentication using various methods
verifier: Verifier,
/// Version associated with the server used by clients to verify compatibility
version: Version,
}
/// Interface for a handler that receives connections and requests
@ -81,6 +84,7 @@ impl Server<()> {
config: Default::default(),
handler: (),
verifier: Verifier::empty(),
version: Default::default(),
}
}
@ -115,6 +119,7 @@ impl<T> Server<T> {
config,
handler: self.handler,
verifier: self.verifier,
version: self.version,
}
}
@ -124,6 +129,7 @@ impl<T> Server<T> {
config: self.config,
handler,
verifier: self.verifier,
version: self.version,
}
}
@ -133,6 +139,17 @@ impl<T> Server<T> {
config: self.config,
handler: self.handler,
verifier,
version: self.version,
}
}
/// Consumes the current server, replacing its version with `version` and returning it.
pub fn version(self, version: Version) -> Self {
Self {
config: self.config,
handler: self.handler,
verifier: self.verifier,
version,
}
}
}
@ -172,6 +189,7 @@ where
config,
handler,
verifier,
version,
} = self;
let handler = Arc::new(handler);
@ -221,6 +239,7 @@ where
.sleep_duration(config.connection_sleep)
.heartbeat_duration(config.connection_heartbeat)
.verifier(Arc::downgrade(&verifier))
.version(version.clone())
.spawn(),
);
@ -253,6 +272,12 @@ mod tests {
use super::*;
use crate::common::{Connection, InmemoryTransport, MpscListener, Request, Response};
macro_rules! server_version {
() => {
Version::new(1, 2, 3)
};
}
pub struct TestServerHandler;
#[async_trait]
@ -275,6 +300,7 @@ mod tests {
config,
handler: TestServerHandler,
verifier: Verifier::new(methods),
version: server_version!(),
}
}
@ -304,7 +330,7 @@ mod tests {
.expect("Failed to start server");
// Perform handshake and authentication with the server before beginning to send data
let mut connection = Connection::client(transport, DummyAuthHandler)
let mut connection = Connection::client(transport, DummyAuthHandler, server_version!())
.await
.expect("Failed to connect to server");

@ -5,7 +5,7 @@ use distant_auth::Verifier;
use serde::de::DeserializeOwned;
use serde::Serialize;
use crate::common::{PortRange, TcpListener};
use crate::common::{PortRange, TcpListener, Version};
use crate::server::{Server, ServerConfig, ServerHandler, TcpServerRef};
pub struct TcpServerBuilder<T>(Server<T>);
@ -35,6 +35,10 @@ impl<T> TcpServerBuilder<T> {
pub fn verifier(self, verifier: Verifier) -> Self {
Self(self.0.verifier(verifier))
}
pub fn version(self, version: Version) -> Self {
Self(self.0.version(version))
}
}
impl<T> TcpServerBuilder<T>

@ -5,7 +5,7 @@ use distant_auth::Verifier;
use serde::de::DeserializeOwned;
use serde::Serialize;
use crate::common::UnixSocketListener;
use crate::common::{UnixSocketListener, Version};
use crate::server::{Server, ServerConfig, ServerHandler, UnixSocketServerRef};
pub struct UnixSocketServerBuilder<T>(Server<T>);
@ -35,6 +35,10 @@ impl<T> UnixSocketServerBuilder<T> {
pub fn verifier(self, verifier: Verifier) -> Self {
Self(self.0.verifier(verifier))
}
pub fn version(self, version: Version) -> Self {
Self(self.0.version(version))
}
}
impl<T> UnixSocketServerBuilder<T>

@ -5,7 +5,7 @@ use distant_auth::Verifier;
use serde::de::DeserializeOwned;
use serde::Serialize;
use crate::common::WindowsPipeListener;
use crate::common::{Version, WindowsPipeListener};
use crate::server::{Server, ServerConfig, ServerHandler, WindowsPipeServerRef};
pub struct WindowsPipeServerBuilder<T>(Server<T>);
@ -35,6 +35,10 @@ impl<T> WindowsPipeServerBuilder<T> {
pub fn verifier(self, verifier: Verifier) -> Self {
Self(self.0.verifier(verifier))
}
pub fn version(self, version: Version) -> Self {
Self(self.0.version(version))
}
}
impl<T> WindowsPipeServerBuilder<T>

@ -14,7 +14,7 @@ use tokio::task::JoinHandle;
use super::{ConnectionState, RequestCtx, ServerHandler, ServerReply, ServerState, ShutdownTimer};
use crate::common::{
Backup, Connection, Frame, Interest, Keychain, Response, Transport, UntypedRequest,
Backup, Connection, Frame, Interest, Keychain, Response, Transport, UntypedRequest, Version,
};
pub type ServerKeychain = Keychain<oneshot::Receiver<Backup>>;
@ -65,6 +65,7 @@ pub(super) struct ConnectionTaskBuilder<H, S, T> {
sleep_duration: Duration,
heartbeat_duration: Duration,
verifier: Weak<Verifier>,
version: Version,
}
impl ConnectionTaskBuilder<(), (), ()> {
@ -80,6 +81,7 @@ impl ConnectionTaskBuilder<(), (), ()> {
sleep_duration: SLEEP_DURATION,
heartbeat_duration: MINIMUM_HEARTBEAT_DURATION,
verifier: Weak::new(),
version: Version::default(),
}
}
}
@ -96,6 +98,7 @@ impl<H, S, T> ConnectionTaskBuilder<H, S, T> {
sleep_duration: self.sleep_duration,
heartbeat_duration: self.heartbeat_duration,
verifier: self.verifier,
version: self.version,
}
}
@ -110,6 +113,7 @@ impl<H, S, T> ConnectionTaskBuilder<H, S, T> {
sleep_duration: self.sleep_duration,
heartbeat_duration: self.heartbeat_duration,
verifier: self.verifier,
version: self.version,
}
}
@ -124,6 +128,7 @@ impl<H, S, T> ConnectionTaskBuilder<H, S, T> {
sleep_duration: self.sleep_duration,
heartbeat_duration: self.heartbeat_duration,
verifier: self.verifier,
version: self.version,
}
}
@ -138,6 +143,7 @@ impl<H, S, T> ConnectionTaskBuilder<H, S, T> {
sleep_duration: self.sleep_duration,
heartbeat_duration: self.heartbeat_duration,
verifier: self.verifier,
version: self.version,
}
}
@ -152,6 +158,7 @@ impl<H, S, T> ConnectionTaskBuilder<H, S, T> {
sleep_duration: self.sleep_duration,
heartbeat_duration: self.heartbeat_duration,
verifier: self.verifier,
version: self.version,
}
}
@ -169,6 +176,7 @@ impl<H, S, T> ConnectionTaskBuilder<H, S, T> {
sleep_duration: self.sleep_duration,
heartbeat_duration: self.heartbeat_duration,
verifier: self.verifier,
version: self.version,
}
}
@ -183,6 +191,7 @@ impl<H, S, T> ConnectionTaskBuilder<H, S, T> {
sleep_duration,
heartbeat_duration: self.heartbeat_duration,
verifier: self.verifier,
version: self.version,
}
}
@ -200,6 +209,7 @@ impl<H, S, T> ConnectionTaskBuilder<H, S, T> {
sleep_duration: self.sleep_duration,
heartbeat_duration,
verifier: self.verifier,
version: self.version,
}
}
@ -214,6 +224,22 @@ impl<H, S, T> ConnectionTaskBuilder<H, S, T> {
sleep_duration: self.sleep_duration,
heartbeat_duration: self.heartbeat_duration,
verifier,
version: self.version,
}
}
pub fn version(self, version: Version) -> ConnectionTaskBuilder<H, S, T> {
ConnectionTaskBuilder {
handler: self.handler,
state: self.state,
keychain: self.keychain,
transport: self.transport,
shutdown: self.shutdown,
shutdown_timer: self.shutdown_timer,
sleep_duration: self.sleep_duration,
heartbeat_duration: self.heartbeat_duration,
verifier: self.verifier,
version,
}
}
}
@ -240,6 +266,7 @@ where
sleep_duration,
heartbeat_duration,
verifier,
version,
} = self;
// NOTE: This exists purely to make the compiler happy for macro_rules declaration order.
@ -408,7 +435,8 @@ where
match await_or_shutdown!(Box::pin(Connection::server(
transport,
verifier.as_ref(),
keychain
keychain,
version
))) {
Ok(connection) => connection,
Err(x) => {
@ -627,6 +655,12 @@ mod tests {
}};
}
macro_rules! server_version {
() => {
Version::new(1, 2, 3)
};
}
#[test(tokio::test)]
async fn should_terminate_if_fails_access_verifier() {
let handler = Arc::new(TestServerHandler);
@ -671,11 +705,12 @@ mod tests {
.transport(t1)
.shutdown_timer(Arc::downgrade(&shutdown_timer))
.verifier(Arc::downgrade(&verifier))
.version(server_version!())
.spawn();
// Spawn a task to handle establishing connection from client-side
tokio::spawn(async move {
let _client = Connection::client(t2, DummyAuthHandler)
let _client = Connection::client(t2, DummyAuthHandler, server_version!())
.await
.expect("Fail to establish client-side connection");
});
@ -704,11 +739,12 @@ mod tests {
.transport(t1)
.shutdown_timer(Arc::downgrade(&shutdown_timer))
.verifier(Arc::downgrade(&verifier))
.version(server_version!())
.spawn();
// Spawn a task to handle establishing connection from client-side
tokio::spawn(async move {
let _client = Connection::client(t2, DummyAuthHandler)
let _client = Connection::client(t2, DummyAuthHandler, server_version!())
.await
.expect("Fail to establish client-side connection");
});
@ -754,12 +790,13 @@ mod tests {
.transport(t1)
.shutdown_timer(Arc::downgrade(&shutdown_timer))
.verifier(Arc::downgrade(&verifier))
.version(server_version!())
.spawn();
// Spawn a task to handle establishing connection from client-side, and then closes to
// trigger the server-side to close
tokio::spawn(async move {
let _client = Connection::client(t2, DummyAuthHandler)
let _client = Connection::client(t2, DummyAuthHandler, server_version!())
.await
.expect("Fail to establish client-side connection");
});
@ -828,12 +865,13 @@ mod tests {
})
.shutdown_timer(Arc::downgrade(&shutdown_timer))
.verifier(Arc::downgrade(&verifier))
.version(server_version!())
.spawn();
// Spawn a task to handle establishing connection from client-side, set ready to fail
// for the server-side after client connection completes, and wait a bit
tokio::spawn(async move {
let _client = Connection::client(t2, DummyAuthHandler)
let _client = Connection::client(t2, DummyAuthHandler, server_version!())
.await
.expect("Fail to establish client-side connection");
@ -872,12 +910,13 @@ mod tests {
.transport(t1)
.shutdown_timer(Arc::downgrade(&shutdown_timer))
.verifier(Arc::downgrade(&verifier))
.version(server_version!())
.spawn();
// Spawn a task to handle establishing connection from client-side, and then closes to
// trigger the server-side to close
tokio::spawn(async move {
let _client = Connection::client(t2, DummyAuthHandler)
let _client = Connection::client(t2, DummyAuthHandler, server_version!())
.await
.expect("Fail to establish client-side connection");
});
@ -902,11 +941,12 @@ mod tests {
.transport(t1)
.shutdown_timer(Arc::downgrade(&shutdown_timer))
.verifier(Arc::downgrade(&verifier))
.version(server_version!())
.spawn();
// Spawn a task to handle establishing connection from client-side
let task = tokio::spawn(async move {
let mut client = Connection::client(t2, DummyAuthHandler)
let mut client = Connection::client(t2, DummyAuthHandler, server_version!())
.await
.expect("Fail to establish client-side connection");
@ -939,11 +979,12 @@ mod tests {
.shutdown_timer(Arc::downgrade(&shutdown_timer))
.heartbeat_duration(Duration::from_millis(200))
.verifier(Arc::downgrade(&verifier))
.version(server_version!())
.spawn();
// Spawn a task to handle establishing connection from client-side
let task = tokio::spawn(async move {
let mut client = Connection::client(t2, DummyAuthHandler)
let mut client = Connection::client(t2, DummyAuthHandler, server_version!())
.await
.expect("Fail to establish client-side connection");
@ -1047,10 +1088,12 @@ mod tests {
.shutdown_timer(Arc::downgrade(&shutdown_timer))
.heartbeat_duration(Duration::from_millis(200))
.verifier(Arc::downgrade(&verifier))
.version(server_version!())
.spawn();
// Spawn a task to handle the client-side establishment of a full connection
let _client_task = tokio::spawn(Connection::client(t2, DummyAuthHandler));
let _client_task =
tokio::spawn(Connection::client(t2, DummyAuthHandler, server_version!()));
// Shutdown server connection task while it is accepting the connection, verifying that we
// do not get an error in return
@ -1099,10 +1142,12 @@ mod tests {
.shutdown_timer(Arc::downgrade(&shutdown_timer))
.heartbeat_duration(Duration::from_millis(200))
.verifier(Arc::downgrade(&verifier))
.version(server_version!())
.spawn();
// Spawn a task to handle the client-side establishment of a full connection
let _client_task = tokio::spawn(Connection::client(t2, DummyAuthHandler));
let _client_task =
tokio::spawn(Connection::client(t2, DummyAuthHandler, server_version!()));
// Wait to ensure we complete the accept call first
let _ = rx.recv().await;

@ -19,8 +19,9 @@ use async_compat::CompatExt;
use async_trait::async_trait;
use distant_core::net::auth::{AuthHandlerMap, DummyAuthHandler, Verifier};
use distant_core::net::client::{Client, ClientConfig};
use distant_core::net::common::{Host, InmemoryTransport, OneshotListener};
use distant_core::net::common::{Host, InmemoryTransport, OneshotListener, Version};
use distant_core::net::server::{Server, ServerRef};
use distant_core::protocol::PROTOCOL_VERSION;
use distant_core::{DistantApiServerHandler, DistantClient, DistantSingleKeyCredentials};
use log::*;
use smol::channel::Receiver as SmolReceiver;
@ -588,6 +589,11 @@ impl Ssh {
match Client::tcp(addr)
.auth_handler(AuthHandlerMap::new().with_static_key(key.clone()))
.connect_timeout(timeout)
.version(Version::new(
PROTOCOL_VERSION.major,
PROTOCOL_VERSION.minor,
PROTOCOL_VERSION.patch,
))
.connect()
.await
{
@ -756,6 +762,11 @@ impl Ssh {
.auth_handler(DummyAuthHandler)
.config(ClientConfig::default().with_maximum_silence_duration())
.connector(t1)
.version(Version::new(
PROTOCOL_VERSION.major,
PROTOCOL_VERSION.minor,
PROTOCOL_VERSION.patch,
))
.connect()
.await?;
Ok((client, server))

Loading…
Cancel
Save