From cf2c5c700e4e34020c3ec5a764f62a63de4f0399 Mon Sep 17 00:00:00 2001 From: Chip Senkbeil Date: Fri, 14 Jul 2023 02:07:28 -0500 Subject: [PATCH] Unfinished refactoring to support configuring versions for connecting -- need to find all places where we build client/server and set version to protocol OR net package version for manager --- distant-core/tests/api_tests.rs | 13 +- distant-net/src/client/builder.rs | 22 +- distant-net/src/common.rs | 2 + distant-net/src/common/connection.rs | 241 +++++++++++++++++++--- distant-net/src/common/version.rs | 119 +++++++++++ distant-net/src/server.rs | 30 ++- distant-net/src/server/builder/tcp.rs | 6 +- distant-net/src/server/builder/unix.rs | 6 +- distant-net/src/server/builder/windows.rs | 6 +- distant-net/src/server/connection.rs | 67 +++++- distant-ssh2/src/lib.rs | 13 +- 11 files changed, 478 insertions(+), 47 deletions(-) create mode 100644 distant-net/src/common/version.rs diff --git a/distant-core/tests/api_tests.rs b/distant-core/tests/api_tests.rs index 516b3f3..7b81fdb 100644 --- a/distant-core/tests/api_tests.rs +++ b/distant-core/tests/api_tests.rs @@ -7,8 +7,9 @@ use distant_core::{ }; use distant_net::auth::{DummyAuthHandler, Verifier}; use distant_net::client::Client; -use distant_net::common::{InmemoryTransport, OneshotListener}; +use distant_net::common::{InmemoryTransport, OneshotListener, Version}; use distant_net::server::{Server, ServerRef}; +use distant_protocol::PROTOCOL_VERSION; /// Stands up an inmemory client and server using the given api. async fn setup(api: impl DistantApi + Send + Sync + 'static) -> (DistantClient, ServerRef) { @@ -17,12 +18,22 @@ async fn setup(api: impl DistantApi + Send + Sync + 'static) -> (DistantClient, let server = Server::new() .handler(DistantApiServerHandler::new(api)) .verifier(Verifier::none()) + .version(Version::new( + PROTOCOL_VERSION.major, + PROTOCOL_VERSION.minor, + PROTOCOL_VERSION.patch, + )) .start(OneshotListener::from_value(t2)) .expect("Failed to start server"); let client: DistantClient = Client::build() .auth_handler(DummyAuthHandler) .connector(t1) + .version(Version::new( + PROTOCOL_VERSION.major, + PROTOCOL_VERSION.minor, + PROTOCOL_VERSION.patch, + )) .connect() .await .expect("Failed to connect to server"); diff --git a/distant-net/src/client/builder.rs b/distant-net/src/client/builder.rs index 10d64b1..9e02495 100644 --- a/distant-net/src/client/builder.rs +++ b/distant-net/src/client/builder.rs @@ -20,7 +20,7 @@ pub use windows::*; use super::ClientConfig; use crate::client::{Client, UntypedClient}; -use crate::common::{Connection, Transport}; +use crate::common::{Connection, Transport, Version}; /// Interface that performs the connection to produce a [`Transport`] for use by the [`Client`]. #[async_trait] @@ -46,6 +46,7 @@ pub struct ClientBuilder { connector: C, config: ClientConfig, connect_timeout: Option, + version: Version, } impl ClientBuilder { @@ -56,6 +57,7 @@ impl ClientBuilder { config: self.config, connector: self.connector, connect_timeout: self.connect_timeout, + version: self.version, } } @@ -66,6 +68,7 @@ impl ClientBuilder { config, connector: self.connector, connect_timeout: self.connect_timeout, + version: self.version, } } @@ -76,6 +79,7 @@ impl ClientBuilder { config: self.config, connector, connect_timeout: self.connect_timeout, + version: self.version, } } @@ -86,6 +90,18 @@ impl ClientBuilder { config: self.config, connector: self.connector, connect_timeout: connect_timeout.into(), + version: self.version, + } + } + + /// Configure the version of the client. + pub fn version(self, version: Version) -> Self { + Self { + auth_handler: self.auth_handler, + config: self.config, + connector: self.connector, + connect_timeout: self.connect_timeout, + version, } } } @@ -97,6 +113,7 @@ impl ClientBuilder<(), ()> { config: Default::default(), connector: (), connect_timeout: None, + version: Default::default(), } } } @@ -119,6 +136,7 @@ where let auth_handler = self.auth_handler; let config = self.config; let connect_timeout = self.connect_timeout; + let version = self.version; let f = async move { let transport = match connect_timeout { @@ -128,7 +146,7 @@ where .and_then(convert::identity)?, None => self.connector.connect().await?, }; - let connection = Connection::client(transport, auth_handler).await?; + let connection = Connection::client(transport, auth_handler, version).await?; Ok(UntypedClient::spawn(connection, config)) }; diff --git a/distant-net/src/common.rs b/distant-net/src/common.rs index 5f793c8..2f24ba4 100644 --- a/distant-net/src/common.rs +++ b/distant-net/src/common.rs @@ -9,6 +9,7 @@ mod packet; mod port; mod transport; pub(crate) mod utils; +mod version; pub use any::*; pub(crate) use connection::Connection; @@ -21,3 +22,4 @@ pub use map::*; pub use packet::*; pub use port::*; pub use transport::*; +pub use version::*; diff --git a/distant-net/src/common/connection.rs b/distant-net/src/common/connection.rs index acb434b..ef6efe4 100644 --- a/distant-net/src/common/connection.rs +++ b/distant-net/src/common/connection.rs @@ -11,6 +11,7 @@ use tokio::sync::oneshot; use crate::common::InmemoryTransport; use crate::common::{ Backup, FramedTransport, HeapSecretKey, Keychain, KeychainResult, Reconnectable, Transport, + TransportExt, Version, }; /// Id of the connection @@ -110,6 +111,19 @@ where debug!("[Conn {id}] Re-establishing connection"); Reconnectable::reconnect(transport).await?; + // Wait for exactly version bytes (24 where 8 bytes for major, minor, patch) + // but with a reconnect we don't actually validate it because we did that + // the first time we connected + // + // NOTE: We do this with the raw transport and not the framed version! + debug!("[Conn {id}] Waiting for server version"); + if transport.as_mut_inner().read_exact(&mut [0u8; 24]).await? != 24 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Wrong version byte len received", + )); + } + // Perform a handshake to ensure that the connection is properly established and encrypted debug!("[Conn {id}] Performing handshake"); transport.client_handshake().await?; @@ -190,13 +204,42 @@ where /// Transforms a raw [`Transport`] into an established [`Connection`] from the client-side by /// performing the following: /// - /// 1. Handshakes to derive the appropriate [`Codec`](crate::Codec) to use - /// 2. Authenticates the established connection to ensure it is valid - /// 3. Restores pre-existing state using the provided backup, replaying any missing frames and + /// 1. Performs a version check with the server + /// 2. Handshakes to derive the appropriate [`Codec`](crate::Codec) to use + /// 3. Authenticates the established connection to ensure it is valid + /// 4. Restores pre-existing state using the provided backup, replaying any missing frames and /// receiving any frames from the other side - pub async fn client(transport: T, handler: H) -> io::Result { + pub async fn client( + transport: T, + handler: H, + version: Version, + ) -> io::Result { let id: ConnectionId = rand::random(); + // Wait for exactly version bytes (24 where 8 bytes for major, minor, patch) + debug!("[Conn {id}] Waiting for server version"); + let mut version_bytes = [0u8; 24]; + if transport.read_exact(&mut version_bytes).await? != 24 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Wrong version byte len received", + )); + } + + // Compare versions for compatibility and drop the connection if incompatible + let server_version = Version::from_be_bytes(version_bytes); + debug!( + "[Conn {id}] Checking compatibility between client {version} & server {server_version}" + ); + if !version.is_compatible_with(&server_version) { + return Err(io::Error::new( + io::ErrorKind::Other, + format!( + "Client version {version} is incompatible with server version {server_version}" + ), + )); + } + // Perform a handshake to ensure that the connection is properly established and encrypted debug!("[Conn {id}] Performing handshake"); let mut transport: FramedTransport = @@ -238,19 +281,25 @@ where /// Transforms a raw [`Transport`] into an established [`Connection`] from the server-side by /// performing the following: /// - /// 1. Handshakes to derive the appropriate [`Codec`](crate::Codec) to use - /// 2. Authenticates the established connection to ensure it is valid by either using the + /// 1. Performs a version check with the client + /// 2. Handshakes to derive the appropriate [`Codec`](crate::Codec) to use + /// 3. Authenticates the established connection to ensure it is valid by either using the /// given `verifier` or, if working with an existing client connection, will validate an OTP /// from our database - /// 3. Restores pre-existing state using the provided backup, replaying any missing frames and + /// 4. Restores pre-existing state using the provided backup, replaying any missing frames and /// receiving any frames from the other side pub async fn server( transport: T, verifier: &Verifier, keychain: Keychain>, + version: Version, ) -> io::Result { let id: ConnectionId = rand::random(); + // Write the version as bytes + debug!("[Conn {id}] Sending version {version}"); + transport.write_all(&version.to_be_bytes()).await?; + // Perform a handshake to ensure that the connection is properly established and encrypted debug!("[Conn {id}] Performing handshake"); let mut transport: FramedTransport = @@ -464,6 +513,60 @@ mod tests { use super::*; use crate::common::Frame; + macro_rules! server_version { + () => { + Version::new(1, 2, 3) + }; + } + + macro_rules! send_server_version { + ($transport:expr, $version:expr) => {{ + ($transport) + .as_mut_inner() + .write_all(&$version.to_be_bytes()) + .await + .unwrap(); + }}; + ($transport:expr) => { + send_server_version!($transport, server_version!()); + }; + } + + macro_rules! receive_version { + ($transport:expr) => {{ + let mut bytes = [0u8; 24]; + assert_eq!( + ($transport) + .as_mut_inner() + .read_exact(&mut bytes) + .await + .unwrap(), + 24, + "Wrong version len received" + ); + Version::from_be_bytes(bytes) + }}; + } + + #[test(tokio::test)] + async fn client_should_fail_when_server_sends_incompatible_version() { + let (mut t1, t2) = FramedTransport::pair(100); + + // Spawn a task to perform the client connection so we don't deadlock while simulating the + // server actions on the other side + let task = tokio::spawn(async move { + Connection::client(t2.into_inner(), DummyAuthHandler, Version::new(1, 2, 3)) + .await + .unwrap() + }); + + // Send invalid version to fail the handshake + send_server_version!(t1, Version::new(2, 0, 0)); + + // Client should fail + task.await.unwrap_err(); + } + #[test(tokio::test)] async fn client_should_fail_if_codec_handshake_fails() { let (mut t1, t2) = FramedTransport::pair(100); @@ -471,11 +574,14 @@ mod tests { // Spawn a task to perform the client connection so we don't deadlock while simulating the // server actions on the other side let task = tokio::spawn(async move { - Connection::client(t2.into_inner(), DummyAuthHandler) + Connection::client(t2.into_inner(), DummyAuthHandler, server_version!()) .await .unwrap() }); + // Send server version for client to confirm + send_server_version!(t1); + // Send garbage to fail the handshake t1.write_frame(Frame::new(b"invalid")).await.unwrap(); @@ -490,11 +596,14 @@ mod tests { // Spawn a task to perform the client connection so we don't deadlock while simulating the // server actions on the other side let task = tokio::spawn(async move { - Connection::client(t2.into_inner(), DummyAuthHandler) + Connection::client(t2.into_inner(), DummyAuthHandler, server_version!()) .await .unwrap() }); + // Send server version for client to confirm + send_server_version!(t1); + // Perform first step of connection by establishing the codec t1.server_handshake().await.unwrap(); @@ -519,11 +628,14 @@ mod tests { // Spawn a task to perform the client connection so we don't deadlock while simulating the // server actions on the other side let task = tokio::spawn(async move { - Connection::client(t2.into_inner(), DummyAuthHandler) + Connection::client(t2.into_inner(), DummyAuthHandler, server_version!()) .await .unwrap() }); + // Send server version for client to confirm + send_server_version!(t1); + // Perform first step of connection by establishing the codec t1.server_handshake().await.unwrap(); @@ -559,11 +671,14 @@ mod tests { // Spawn a task to perform the client connection so we don't deadlock while simulating the // server actions on the other side let task = tokio::spawn(async move { - Connection::client(t2.into_inner(), DummyAuthHandler) + Connection::client(t2.into_inner(), DummyAuthHandler, server_version!()) .await .unwrap() }); + // Send server version for client to confirm + send_server_version!(t1); + // Perform first step of connection by establishing the codec t1.server_handshake().await.unwrap(); @@ -597,11 +712,14 @@ mod tests { // Spawn a task to perform the client connection so we don't deadlock while simulating the // server actions on the other side let task = tokio::spawn(async move { - Connection::client(t2.into_inner(), DummyAuthHandler) + Connection::client(t2.into_inner(), DummyAuthHandler, server_version!()) .await .unwrap() }); + // Send server version for client to confirm + send_server_version!(t1); + // Perform first step of connection by establishing the codec t1.server_handshake().await.unwrap(); @@ -629,6 +747,30 @@ mod tests { assert_eq!(client.otp(), Some(&otp)); } + #[test(tokio::test)] + async fn server_should_fail_if_client_drops_due_to_version() { + let (mut t1, t2) = FramedTransport::pair(100); + let verifier = Verifier::none(); + let keychain = Keychain::new(); + + // Spawn a task to perform the server connection so we don't deadlock while simulating the + // client actions on the other side + let task = tokio::spawn(async move { + Connection::server(t2.into_inner(), &verifier, keychain, server_version!()) + .await + .unwrap() + }); + + // Receive the version from the server + let _ = receive_version!(t1); + + // Drop client connection as a result of an "incompatible version" + drop(t1); + + // Server should fail + task.await.unwrap_err(); + } + #[test(tokio::test)] async fn server_should_fail_if_codec_handshake_fails() { let (mut t1, t2) = FramedTransport::pair(100); @@ -638,11 +780,14 @@ mod tests { // Spawn a task to perform the server connection so we don't deadlock while simulating the // client actions on the other side let task = tokio::spawn(async move { - Connection::server(t2.into_inner(), &verifier, keychain) + Connection::server(t2.into_inner(), &verifier, keychain, server_version!()) .await .unwrap() }); + // Receive the version from the server + let _ = receive_version!(t1); + // Send garbage to fail the handshake t1.write_frame(Frame::new(b"invalid")).await.unwrap(); @@ -659,11 +804,14 @@ mod tests { // Spawn a task to perform the server connection so we don't deadlock while simulating the // client actions on the other side let task = tokio::spawn(async move { - Connection::server(t2.into_inner(), &verifier, keychain) + Connection::server(t2.into_inner(), &verifier, keychain, server_version!()) .await .unwrap() }); + // Receive the version from the server + let _ = receive_version!(t1); + // Perform first step of completing client-side of handshake t1.client_handshake().await.unwrap(); @@ -683,11 +831,14 @@ mod tests { // Spawn a task to perform the server connection so we don't deadlock while simulating the // client actions on the other side let task = tokio::spawn(async move { - Connection::server(t2.into_inner(), &verifier, keychain) + Connection::server(t2.into_inner(), &verifier, keychain, server_version!()) .await .unwrap() }); + // Receive the version from the server + let _ = receive_version!(t1); + // Perform first step of completing client-side of handshake t1.client_handshake().await.unwrap(); @@ -717,11 +868,14 @@ mod tests { // Spawn a task to perform the server connection so we don't deadlock while simulating the // client actions on the other side let task = tokio::spawn(async move { - Connection::server(t2.into_inner(), &verifier, keychain) + Connection::server(t2.into_inner(), &verifier, keychain, server_version!()) .await .unwrap() }); + // Receive the version from the server + let _ = receive_version!(t1); + // Perform first step of completing client-side of handshake t1.client_handshake().await.unwrap(); @@ -750,11 +904,14 @@ mod tests { // Spawn a task to perform the server connection so we don't deadlock while simulating the // client actions on the other side let task = tokio::spawn(async move { - Connection::server(t2.into_inner(), &verifier, keychain) + Connection::server(t2.into_inner(), &verifier, keychain, server_version!()) .await .unwrap() }); + // Receive the version from the server + let _ = receive_version!(t1); + // Perform first step of completing client-side of handshake t1.client_handshake().await.unwrap(); @@ -790,11 +947,14 @@ mod tests { // Spawn a task to perform the server connection so we don't deadlock while simulating the // client actions on the other side let task = tokio::spawn(async move { - Connection::server(t2.into_inner(), &verifier, keychain) + Connection::server(t2.into_inner(), &verifier, keychain, server_version!()) .await .unwrap() }); + // Receive the version from the server + let _ = receive_version!(t1); + // Perform first step of completing client-side of handshake t1.client_handshake().await.unwrap(); @@ -828,11 +988,14 @@ mod tests { // Spawn a task to perform the server connection so we don't deadlock while simulating the // client actions on the other side let task = tokio::spawn(async move { - Connection::server(t2.into_inner(), &verifier, keychain) + Connection::server(t2.into_inner(), &verifier, keychain, server_version!()) .await .unwrap() }); + // Receive the version from the server + let _ = receive_version!(t1); + // Perform first step of completing client-side of handshake t1.client_handshake().await.unwrap(); @@ -866,11 +1029,14 @@ mod tests { // Spawn a task to perform the server connection so we don't deadlock while simulating the // client actions on the other side let task = tokio::spawn(async move { - Connection::server(t2.into_inner(), &verifier, keychain) + Connection::server(t2.into_inner(), &verifier, keychain, server_version!()) .await .unwrap() }); + // Receive the version from the server + let _ = receive_version!(t1); + // Perform first step of completing client-side of handshake t1.client_handshake().await.unwrap(); @@ -904,12 +1070,15 @@ mod tests { let task = tokio::spawn({ let keychain = keychain.clone(); async move { - Connection::server(t2.into_inner(), &verifier, keychain) + Connection::server(t2.into_inner(), &verifier, keychain, server_version!()) .await .unwrap() } }); + // Receive the version from the server + let _ = receive_version!(t1); + // Perform first step of completing client-side of handshake t1.client_handshake().await.unwrap(); @@ -969,12 +1138,15 @@ mod tests { let task = tokio::spawn({ let keychain = keychain.clone(); async move { - Connection::server(t2.into_inner(), &verifier, keychain) + Connection::server(t2.into_inner(), &verifier, keychain, server_version!()) .await .unwrap() } }); + // Receive the version from the server + let _ = receive_version!(t1); + // Perform first step of completing client-side of handshake t1.client_handshake().await.unwrap(); @@ -1029,13 +1201,13 @@ mod tests { // Spawn a task to perform the server connection so we don't deadlock let task = tokio::spawn(async move { - Connection::server(t2, &verifier, keychain) + Connection::server(t2, &verifier, keychain, server_version!()) .await .expect("Failed to connect from server") }); // Perform the client-side of the connection - let mut client = Connection::client(t1, DummyAuthHandler) + let mut client = Connection::client(t1, DummyAuthHandler, server_version!()) .await .expect("Failed to connect from client"); let mut server = task.await.unwrap(); @@ -1063,14 +1235,14 @@ mod tests { let verifier = Arc::clone(&verifier); let keychain = keychain.clone(); tokio::spawn(async move { - Connection::server(t2, &verifier, keychain) + Connection::server(t2, &verifier, keychain, server_version!()) .await .expect("Failed to connect from server") }) }; // Perform the client-side of the connection - let mut client = Connection::client(t1, DummyAuthHandler) + let mut client = Connection::client(t1, DummyAuthHandler, server_version!()) .await .expect("Failed to connect from client"); @@ -1093,6 +1265,9 @@ mod tests { // Spawn a task to perform the client reconnection so we don't deadlock let task = tokio::spawn(async move { client.reconnect().await.unwrap() }); + // Send a version, although it'll be ignored by a reconnecting client + send_server_version!(transport); + // Send garbage to fail handshake from server-side transport.write_frame(b"hello").await.unwrap(); @@ -1108,6 +1283,9 @@ mod tests { // Spawn a task to perform the client reconnection so we don't deadlock let task = tokio::spawn(async move { client.reconnect().await.unwrap() }); + // Send a version, although it'll be ignored by a reconnecting client + send_server_version!(transport); + // Perform first step of completing server-side of handshake transport.server_handshake().await.unwrap(); @@ -1126,6 +1304,9 @@ mod tests { // Spawn a task to perform the client reconnection so we don't deadlock let task = tokio::spawn(async move { client.reconnect().await.unwrap() }); + // Send a version, although it'll be ignored by a reconnecting client + send_server_version!(transport); + // Perform first step of completing server-side of handshake transport.server_handshake().await.unwrap(); @@ -1162,6 +1343,9 @@ mod tests { // Spawn a task to perform the client reconnection so we don't deadlock let task = tokio::spawn(async move { client.reconnect().await.unwrap() }); + // Send a version, although it'll be ignored by a reconnecting client + send_server_version!(transport); + // Perform first step of completing server-side of handshake transport.server_handshake().await.unwrap(); @@ -1205,6 +1389,9 @@ mod tests { client }); + // Send a version, although it'll be ignored by a reconnecting client + send_server_version!(transport); + // Perform first step of completing server-side of handshake transport.server_handshake().await.unwrap(); @@ -1275,7 +1462,7 @@ mod tests { // Spawn a task to perform the server reconnection so we don't deadlock let task = tokio::spawn(async move { - Connection::server(transport, &verifier, keychain) + Connection::server(transport, &verifier, keychain, server_version!()) .await .expect("Failed to connect from server") }); diff --git a/distant-net/src/common/version.rs b/distant-net/src/common/version.rs new file mode 100644 index 0000000..b9f9cae --- /dev/null +++ b/distant-net/src/common/version.rs @@ -0,0 +1,119 @@ +use semver::{Comparator, Op, Prerelease, Version as SemVer, VersionReq}; +use std::fmt; + +/// Represents a version and compatibility rules. +#[derive(Clone, Debug)] +pub struct Version { + inner: SemVer, + rules: VersionReq, +} + +impl Version { + /// Creates a new version in the form `major.minor.patch` with a ruleset that is used to check + /// other versions such that `>=0.1.2, <0.2.0` or `>=1.2.3, <2` depending on whether or not the + /// major version is `0`. + /// + /// ``` + /// use distant_net::common::Version; + /// + /// // Matching versions are compatible + /// let a = Version::new(1, 2, 3); + /// let b = Version::new(1, 2, 3); + /// assert!(a.is_compatible_with(&b)); + /// + /// // Version 1.2.3 is compatible with 1.2.4, but not the other way + /// let a = Version::new(1, 2, 3); + /// let b = Version::new(1, 2, 4); + /// assert!(a.is_compatible_with(&b)); + /// assert!(!b.is_compatible_with(&a)); + /// + /// // Version 1.2.3 is compatible with 1.3.0, but not 2 + /// let a = Version::new(1, 2, 3); + /// assert!(a.is_compatible_with(&Version::new(1, 3, 0))); + /// assert!(!a.is_compatible_with(&Version::new(2, 0, 0))); + /// + /// // Version 0.1.2 is compatible with 0.1.3, but not the other way + /// let a = Version::new(0, 1, 2); + /// let b = Version::new(0, 1, 3); + /// assert!(a.is_compatible_with(&b)); + /// assert!(!b.is_compatible_with(&a)); + /// + /// // Version 0.1.2 is not compatible with 0.2 + /// let a = Version::new(0, 1, 2); + /// let b = Version::new(0, 2, 0); + /// assert!(!a.is_compatible_with(&b)); + /// assert!(!b.is_compatible_with(&a)); + /// ``` + pub fn new(major: u64, minor: u64, patch: u64) -> Self { + Self { + inner: SemVer::new(major, minor, patch), + rules: VersionReq { + comparators: vec![ + Comparator { + op: Op::GreaterEq, + major, + minor: Some(minor), + patch: Some(patch), + pre: Prerelease::EMPTY, + }, + Comparator { + op: Op::Less, + major: if major == 0 { 0 } else { major + 1 }, + minor: if major == 0 { Some(minor + 1) } else { None }, + patch: None, + pre: Prerelease::EMPTY, + }, + ], + }, + } + } + + /// Returns true if this version is compatible with another version. + pub fn is_compatible_with(&self, other: &Self) -> bool { + self.rules.matches(&other.inner) + } + + /// Converts from a collection of bytes into a version using the byte form major/minor/patch + /// using big endian. + pub fn from_be_bytes(bytes: [u8; 24]) -> Self { + Self::new( + u64::from_be_bytes([ + bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], + ]), + u64::from_be_bytes([ + bytes[8], bytes[9], bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], + bytes[15], + ]), + u64::from_be_bytes([ + bytes[16], bytes[17], bytes[18], bytes[19], bytes[20], bytes[21], bytes[22], + bytes[23], + ]), + ) + } + + /// Converts the version into a byte form of major/minor/patch using big endian. + pub const fn to_be_bytes(&self) -> [u8; 24] { + let major = self.inner.major.to_be_bytes(); + let minor = self.inner.minor.to_be_bytes(); + let patch = self.inner.patch.to_be_bytes(); + + [ + major[0], major[1], major[2], major[3], major[4], major[5], major[6], major[7], + minor[0], minor[1], minor[2], minor[3], minor[4], minor[5], minor[6], minor[7], + patch[0], patch[1], patch[2], patch[3], patch[4], patch[5], patch[6], patch[7], + ] + } +} + +impl Default for Version { + /// Default version is `0.0.0`. + fn default() -> Self { + Self::new(0, 0, 0) + } +} + +impl fmt::Display for Version { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.inner) + } +} diff --git a/distant-net/src/server.rs b/distant-net/src/server.rs index 6f5c677..218b996 100644 --- a/distant-net/src/server.rs +++ b/distant-net/src/server.rs @@ -9,7 +9,7 @@ use serde::de::DeserializeOwned; use serde::Serialize; use tokio::sync::{broadcast, RwLock}; -use crate::common::{ConnectionId, Listener, Response, Transport}; +use crate::common::{ConnectionId, Listener, Response, Transport, Version}; mod builder; pub use builder::*; @@ -45,6 +45,9 @@ pub struct Server { /// Performs authentication using various methods verifier: Verifier, + + /// Version associated with the server used by clients to verify compatibility + version: Version, } /// Interface for a handler that receives connections and requests @@ -81,6 +84,7 @@ impl Server<()> { config: Default::default(), handler: (), verifier: Verifier::empty(), + version: Default::default(), } } @@ -115,6 +119,7 @@ impl Server { config, handler: self.handler, verifier: self.verifier, + version: self.version, } } @@ -124,6 +129,7 @@ impl Server { config: self.config, handler, verifier: self.verifier, + version: self.version, } } @@ -133,6 +139,17 @@ impl Server { config: self.config, handler: self.handler, verifier, + version: self.version, + } + } + + /// Consumes the current server, replacing its version with `version` and returning it. + pub fn version(self, version: Version) -> Self { + Self { + config: self.config, + handler: self.handler, + verifier: self.verifier, + version, } } } @@ -172,6 +189,7 @@ where config, handler, verifier, + version, } = self; let handler = Arc::new(handler); @@ -221,6 +239,7 @@ where .sleep_duration(config.connection_sleep) .heartbeat_duration(config.connection_heartbeat) .verifier(Arc::downgrade(&verifier)) + .version(version.clone()) .spawn(), ); @@ -253,6 +272,12 @@ mod tests { use super::*; use crate::common::{Connection, InmemoryTransport, MpscListener, Request, Response}; + macro_rules! server_version { + () => { + Version::new(1, 2, 3) + }; + } + pub struct TestServerHandler; #[async_trait] @@ -275,6 +300,7 @@ mod tests { config, handler: TestServerHandler, verifier: Verifier::new(methods), + version: server_version!(), } } @@ -304,7 +330,7 @@ mod tests { .expect("Failed to start server"); // Perform handshake and authentication with the server before beginning to send data - let mut connection = Connection::client(transport, DummyAuthHandler) + let mut connection = Connection::client(transport, DummyAuthHandler, server_version!()) .await .expect("Failed to connect to server"); diff --git a/distant-net/src/server/builder/tcp.rs b/distant-net/src/server/builder/tcp.rs index 2129d69..be35590 100644 --- a/distant-net/src/server/builder/tcp.rs +++ b/distant-net/src/server/builder/tcp.rs @@ -5,7 +5,7 @@ use distant_auth::Verifier; use serde::de::DeserializeOwned; use serde::Serialize; -use crate::common::{PortRange, TcpListener}; +use crate::common::{PortRange, TcpListener, Version}; use crate::server::{Server, ServerConfig, ServerHandler, TcpServerRef}; pub struct TcpServerBuilder(Server); @@ -35,6 +35,10 @@ impl TcpServerBuilder { pub fn verifier(self, verifier: Verifier) -> Self { Self(self.0.verifier(verifier)) } + + pub fn version(self, version: Version) -> Self { + Self(self.0.version(version)) + } } impl TcpServerBuilder diff --git a/distant-net/src/server/builder/unix.rs b/distant-net/src/server/builder/unix.rs index 4bddc1c..1037b65 100644 --- a/distant-net/src/server/builder/unix.rs +++ b/distant-net/src/server/builder/unix.rs @@ -5,7 +5,7 @@ use distant_auth::Verifier; use serde::de::DeserializeOwned; use serde::Serialize; -use crate::common::UnixSocketListener; +use crate::common::{UnixSocketListener, Version}; use crate::server::{Server, ServerConfig, ServerHandler, UnixSocketServerRef}; pub struct UnixSocketServerBuilder(Server); @@ -35,6 +35,10 @@ impl UnixSocketServerBuilder { pub fn verifier(self, verifier: Verifier) -> Self { Self(self.0.verifier(verifier)) } + + pub fn version(self, version: Version) -> Self { + Self(self.0.version(version)) + } } impl UnixSocketServerBuilder diff --git a/distant-net/src/server/builder/windows.rs b/distant-net/src/server/builder/windows.rs index 5d506fa..eb7f3d3 100644 --- a/distant-net/src/server/builder/windows.rs +++ b/distant-net/src/server/builder/windows.rs @@ -5,7 +5,7 @@ use distant_auth::Verifier; use serde::de::DeserializeOwned; use serde::Serialize; -use crate::common::WindowsPipeListener; +use crate::common::{Version, WindowsPipeListener}; use crate::server::{Server, ServerConfig, ServerHandler, WindowsPipeServerRef}; pub struct WindowsPipeServerBuilder(Server); @@ -35,6 +35,10 @@ impl WindowsPipeServerBuilder { pub fn verifier(self, verifier: Verifier) -> Self { Self(self.0.verifier(verifier)) } + + pub fn version(self, version: Version) -> Self { + Self(self.0.version(version)) + } } impl WindowsPipeServerBuilder diff --git a/distant-net/src/server/connection.rs b/distant-net/src/server/connection.rs index 43658f9..540836e 100644 --- a/distant-net/src/server/connection.rs +++ b/distant-net/src/server/connection.rs @@ -14,7 +14,7 @@ use tokio::task::JoinHandle; use super::{ConnectionState, RequestCtx, ServerHandler, ServerReply, ServerState, ShutdownTimer}; use crate::common::{ - Backup, Connection, Frame, Interest, Keychain, Response, Transport, UntypedRequest, + Backup, Connection, Frame, Interest, Keychain, Response, Transport, UntypedRequest, Version, }; pub type ServerKeychain = Keychain>; @@ -65,6 +65,7 @@ pub(super) struct ConnectionTaskBuilder { sleep_duration: Duration, heartbeat_duration: Duration, verifier: Weak, + version: Version, } impl ConnectionTaskBuilder<(), (), ()> { @@ -80,6 +81,7 @@ impl ConnectionTaskBuilder<(), (), ()> { sleep_duration: SLEEP_DURATION, heartbeat_duration: MINIMUM_HEARTBEAT_DURATION, verifier: Weak::new(), + version: Version::default(), } } } @@ -96,6 +98,7 @@ impl ConnectionTaskBuilder { sleep_duration: self.sleep_duration, heartbeat_duration: self.heartbeat_duration, verifier: self.verifier, + version: self.version, } } @@ -110,6 +113,7 @@ impl ConnectionTaskBuilder { sleep_duration: self.sleep_duration, heartbeat_duration: self.heartbeat_duration, verifier: self.verifier, + version: self.version, } } @@ -124,6 +128,7 @@ impl ConnectionTaskBuilder { sleep_duration: self.sleep_duration, heartbeat_duration: self.heartbeat_duration, verifier: self.verifier, + version: self.version, } } @@ -138,6 +143,7 @@ impl ConnectionTaskBuilder { sleep_duration: self.sleep_duration, heartbeat_duration: self.heartbeat_duration, verifier: self.verifier, + version: self.version, } } @@ -152,6 +158,7 @@ impl ConnectionTaskBuilder { sleep_duration: self.sleep_duration, heartbeat_duration: self.heartbeat_duration, verifier: self.verifier, + version: self.version, } } @@ -169,6 +176,7 @@ impl ConnectionTaskBuilder { sleep_duration: self.sleep_duration, heartbeat_duration: self.heartbeat_duration, verifier: self.verifier, + version: self.version, } } @@ -183,6 +191,7 @@ impl ConnectionTaskBuilder { sleep_duration, heartbeat_duration: self.heartbeat_duration, verifier: self.verifier, + version: self.version, } } @@ -200,6 +209,7 @@ impl ConnectionTaskBuilder { sleep_duration: self.sleep_duration, heartbeat_duration, verifier: self.verifier, + version: self.version, } } @@ -214,6 +224,22 @@ impl ConnectionTaskBuilder { sleep_duration: self.sleep_duration, heartbeat_duration: self.heartbeat_duration, verifier, + version: self.version, + } + } + + pub fn version(self, version: Version) -> ConnectionTaskBuilder { + ConnectionTaskBuilder { + handler: self.handler, + state: self.state, + keychain: self.keychain, + transport: self.transport, + shutdown: self.shutdown, + shutdown_timer: self.shutdown_timer, + sleep_duration: self.sleep_duration, + heartbeat_duration: self.heartbeat_duration, + verifier: self.verifier, + version, } } } @@ -240,6 +266,7 @@ where sleep_duration, heartbeat_duration, verifier, + version, } = self; // NOTE: This exists purely to make the compiler happy for macro_rules declaration order. @@ -408,7 +435,8 @@ where match await_or_shutdown!(Box::pin(Connection::server( transport, verifier.as_ref(), - keychain + keychain, + version ))) { Ok(connection) => connection, Err(x) => { @@ -627,6 +655,12 @@ mod tests { }}; } + macro_rules! server_version { + () => { + Version::new(1, 2, 3) + }; + } + #[test(tokio::test)] async fn should_terminate_if_fails_access_verifier() { let handler = Arc::new(TestServerHandler); @@ -671,11 +705,12 @@ mod tests { .transport(t1) .shutdown_timer(Arc::downgrade(&shutdown_timer)) .verifier(Arc::downgrade(&verifier)) + .version(server_version!()) .spawn(); // Spawn a task to handle establishing connection from client-side tokio::spawn(async move { - let _client = Connection::client(t2, DummyAuthHandler) + let _client = Connection::client(t2, DummyAuthHandler, server_version!()) .await .expect("Fail to establish client-side connection"); }); @@ -704,11 +739,12 @@ mod tests { .transport(t1) .shutdown_timer(Arc::downgrade(&shutdown_timer)) .verifier(Arc::downgrade(&verifier)) + .version(server_version!()) .spawn(); // Spawn a task to handle establishing connection from client-side tokio::spawn(async move { - let _client = Connection::client(t2, DummyAuthHandler) + let _client = Connection::client(t2, DummyAuthHandler, server_version!()) .await .expect("Fail to establish client-side connection"); }); @@ -754,12 +790,13 @@ mod tests { .transport(t1) .shutdown_timer(Arc::downgrade(&shutdown_timer)) .verifier(Arc::downgrade(&verifier)) + .version(server_version!()) .spawn(); // Spawn a task to handle establishing connection from client-side, and then closes to // trigger the server-side to close tokio::spawn(async move { - let _client = Connection::client(t2, DummyAuthHandler) + let _client = Connection::client(t2, DummyAuthHandler, server_version!()) .await .expect("Fail to establish client-side connection"); }); @@ -828,12 +865,13 @@ mod tests { }) .shutdown_timer(Arc::downgrade(&shutdown_timer)) .verifier(Arc::downgrade(&verifier)) + .version(server_version!()) .spawn(); // Spawn a task to handle establishing connection from client-side, set ready to fail // for the server-side after client connection completes, and wait a bit tokio::spawn(async move { - let _client = Connection::client(t2, DummyAuthHandler) + let _client = Connection::client(t2, DummyAuthHandler, server_version!()) .await .expect("Fail to establish client-side connection"); @@ -872,12 +910,13 @@ mod tests { .transport(t1) .shutdown_timer(Arc::downgrade(&shutdown_timer)) .verifier(Arc::downgrade(&verifier)) + .version(server_version!()) .spawn(); // Spawn a task to handle establishing connection from client-side, and then closes to // trigger the server-side to close tokio::spawn(async move { - let _client = Connection::client(t2, DummyAuthHandler) + let _client = Connection::client(t2, DummyAuthHandler, server_version!()) .await .expect("Fail to establish client-side connection"); }); @@ -902,11 +941,12 @@ mod tests { .transport(t1) .shutdown_timer(Arc::downgrade(&shutdown_timer)) .verifier(Arc::downgrade(&verifier)) + .version(server_version!()) .spawn(); // Spawn a task to handle establishing connection from client-side let task = tokio::spawn(async move { - let mut client = Connection::client(t2, DummyAuthHandler) + let mut client = Connection::client(t2, DummyAuthHandler, server_version!()) .await .expect("Fail to establish client-side connection"); @@ -939,11 +979,12 @@ mod tests { .shutdown_timer(Arc::downgrade(&shutdown_timer)) .heartbeat_duration(Duration::from_millis(200)) .verifier(Arc::downgrade(&verifier)) + .version(server_version!()) .spawn(); // Spawn a task to handle establishing connection from client-side let task = tokio::spawn(async move { - let mut client = Connection::client(t2, DummyAuthHandler) + let mut client = Connection::client(t2, DummyAuthHandler, server_version!()) .await .expect("Fail to establish client-side connection"); @@ -1047,10 +1088,12 @@ mod tests { .shutdown_timer(Arc::downgrade(&shutdown_timer)) .heartbeat_duration(Duration::from_millis(200)) .verifier(Arc::downgrade(&verifier)) + .version(server_version!()) .spawn(); // Spawn a task to handle the client-side establishment of a full connection - let _client_task = tokio::spawn(Connection::client(t2, DummyAuthHandler)); + let _client_task = + tokio::spawn(Connection::client(t2, DummyAuthHandler, server_version!())); // Shutdown server connection task while it is accepting the connection, verifying that we // do not get an error in return @@ -1099,10 +1142,12 @@ mod tests { .shutdown_timer(Arc::downgrade(&shutdown_timer)) .heartbeat_duration(Duration::from_millis(200)) .verifier(Arc::downgrade(&verifier)) + .version(server_version!()) .spawn(); // Spawn a task to handle the client-side establishment of a full connection - let _client_task = tokio::spawn(Connection::client(t2, DummyAuthHandler)); + let _client_task = + tokio::spawn(Connection::client(t2, DummyAuthHandler, server_version!())); // Wait to ensure we complete the accept call first let _ = rx.recv().await; diff --git a/distant-ssh2/src/lib.rs b/distant-ssh2/src/lib.rs index aefe3be..78a24ee 100644 --- a/distant-ssh2/src/lib.rs +++ b/distant-ssh2/src/lib.rs @@ -19,8 +19,9 @@ use async_compat::CompatExt; use async_trait::async_trait; use distant_core::net::auth::{AuthHandlerMap, DummyAuthHandler, Verifier}; use distant_core::net::client::{Client, ClientConfig}; -use distant_core::net::common::{Host, InmemoryTransport, OneshotListener}; +use distant_core::net::common::{Host, InmemoryTransport, OneshotListener, Version}; use distant_core::net::server::{Server, ServerRef}; +use distant_core::protocol::PROTOCOL_VERSION; use distant_core::{DistantApiServerHandler, DistantClient, DistantSingleKeyCredentials}; use log::*; use smol::channel::Receiver as SmolReceiver; @@ -588,6 +589,11 @@ impl Ssh { match Client::tcp(addr) .auth_handler(AuthHandlerMap::new().with_static_key(key.clone())) .connect_timeout(timeout) + .version(Version::new( + PROTOCOL_VERSION.major, + PROTOCOL_VERSION.minor, + PROTOCOL_VERSION.patch, + )) .connect() .await { @@ -756,6 +762,11 @@ impl Ssh { .auth_handler(DummyAuthHandler) .config(ClientConfig::default().with_maximum_silence_duration()) .connector(t1) + .version(Version::new( + PROTOCOL_VERSION.major, + PROTOCOL_VERSION.minor, + PROTOCOL_VERSION.patch, + )) .connect() .await?; Ok((client, server))