pull/146/head
Chip Senkbeil 2 years ago
parent 3e5741050f
commit 2192184b1d
No known key found for this signature in database
GPG Key ID: 35EF1F8EC72A4131

@ -595,7 +595,8 @@ mod tests {
local_data: &mut local_data,
},
)
.await;
.await
.unwrap();
let ctx = DistantCtx {
connection_id,
reply,

@ -365,21 +365,35 @@ impl Client<(), ()> {
ClientBuilder::new()
}
/// Creates a new [`TcpClientBuilder`].
pub fn tcp() -> TcpClientBuilder<()> {
TcpClientBuilder::new()
/// Creates a new [`ClientBuilder`] configured to use a [`TcpConnector`].
pub fn tcp<T>(connector: impl Into<TcpConnector<T>>) -> ClientBuilder<(), TcpConnector<T>> {
ClientBuilder::new().connector(connector.into())
}
/// Creates a new [`UnixSocketClientBuilder`].
/// Creates a new [`ClientBuilder`] configured to use a [`UnixSocketConnector`].
#[cfg(unix)]
pub fn unix_socket() -> UnixSocketClientBuilder<()> {
UnixSocketClientBuilder::new()
pub fn unix_socket(
connector: impl Into<UnixSocketConnector>,
) -> ClientBuilder<(), UnixSocketConnector> {
ClientBuilder::new().connector(connector.into())
}
/// Creates a new [`WindowsPipeClientBuilder`].
/// Creates a new [`ClientBuilder`] configured to use a local [`WindowsPipeConnector`].
#[cfg(windows)]
pub fn windows_pipe() -> WindowsPipeClientBuilder<()> {
WindowsPipeClientBuilder::new()
pub fn local_windows_pipe(
connector: impl Into<WindowsPipeConnector>,
) -> ClientBuilder<(), WindowsPipeConnector> {
let mut connector = connector.into();
connector.local = true;
ClientBuilder::new().connector(connector)
}
/// Creates a new [`ClientBuilder`] configured to use a [`WindowsPipeConnector`].
#[cfg(windows)]
pub fn windows_pipe(
connector: impl Into<WindowsPipeConnector>,
) -> ClientBuilder<(), WindowsPipeConnector> {
ClientBuilder::new().connector(connector.into())
}
}

@ -13,57 +13,61 @@ mod windows;
#[cfg(windows)]
pub use windows::*;
use crate::client::{Client, ReconnectStrategy};
use crate::client::{Client, ReconnectStrategy, UntypedClient};
use crate::common::{authentication::AuthHandler, Connection, Transport};
use serde::{de::DeserializeOwned, Serialize};
use std::{convert, future::Future, io, time::Duration};
use async_trait::async_trait;
use std::{convert, io, time::Duration};
/// Builder for a [`Client`].
pub struct ClientBuilder<H, T> {
/// Interface that performs the connection to produce a [`Transport`] for use by the [`Client`].
#[async_trait]
pub trait Connector {
/// Type of transport produced by the connection.
type Transport: Transport + 'static;
async fn connect(self) -> io::Result<Self::Transport>;
}
#[async_trait]
impl<T: Transport + 'static> Connector for T {
type Transport = T;
async fn connect(self) -> io::Result<Self::Transport> {
Ok(self)
}
}
/// Builder for a [`Client`] or [`UntypedClient`].
pub struct ClientBuilder<H, C> {
auth_handler: H,
connector: C,
reconnect_strategy: ReconnectStrategy,
transport: T,
timeout: Option<Duration>,
}
impl<H, T> ClientBuilder<H, T> {
pub fn auth_handler<U>(self, auth_handler: U) -> ClientBuilder<U, T> {
impl<H, C> ClientBuilder<H, C> {
pub fn auth_handler<U>(self, auth_handler: U) -> ClientBuilder<U, C> {
ClientBuilder {
auth_handler,
connector: self.connector,
reconnect_strategy: self.reconnect_strategy,
transport: self.transport,
timeout: self.timeout,
}
}
pub fn reconnect_strategy(self, reconnect_strategy: ReconnectStrategy) -> ClientBuilder<H, T> {
pub fn connector<U>(self, connector: U) -> ClientBuilder<H, U> {
ClientBuilder {
auth_handler: self.auth_handler,
reconnect_strategy,
transport: self.transport,
connector,
reconnect_strategy: self.reconnect_strategy,
timeout: self.timeout,
}
}
pub async fn try_transport<U>(
self,
f: impl Future<Output = io::Result<U>>,
) -> io::Result<ClientBuilder<H, U>> {
let timeout = self.timeout.as_ref().copied();
Ok(self.transport(match timeout {
Some(duration) => tokio::time::timeout(duration, f)
.await
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
.and_then(convert::identity)?,
None => f.await?,
}))
}
pub fn transport<U>(self, transport: U) -> ClientBuilder<H, U> {
pub fn reconnect_strategy(self, reconnect_strategy: ReconnectStrategy) -> ClientBuilder<H, C> {
ClientBuilder {
auth_handler: self.auth_handler,
reconnect_strategy: self.reconnect_strategy,
transport,
connector: self.connector,
reconnect_strategy,
timeout: self.timeout,
}
}
@ -71,8 +75,8 @@ impl<H, T> ClientBuilder<H, T> {
pub fn timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
Self {
auth_handler: self.auth_handler,
connector: self.connector,
reconnect_strategy: self.reconnect_strategy,
transport: self.transport,
timeout: timeout.into(),
}
}
@ -83,7 +87,7 @@ impl ClientBuilder<(), ()> {
Self {
auth_handler: (),
reconnect_strategy: ReconnectStrategy::default(),
transport: (),
connector: (),
timeout: None,
}
}
@ -95,27 +99,29 @@ impl Default for ClientBuilder<(), ()> {
}
}
impl<H, T> ClientBuilder<H, T>
impl<H, C> ClientBuilder<H, C>
where
H: AuthHandler + Send,
T: Transport + 'static,
C: Connector,
{
/// Establishes a connection with a remote server using the configured [`Transport`]
/// and other settings, returning a new [`Client`] instance once the connection
/// and other settings, returning a new [`UntypedClient`] instance once the connection
/// is fully established and authenticated.
pub async fn connect<U, V>(self) -> io::Result<Client<U, V>>
where
U: Send + Sync + Serialize + 'static,
V: Send + Sync + DeserializeOwned + 'static,
{
pub async fn connect_untyped(self) -> io::Result<UntypedClient> {
let auth_handler = self.auth_handler;
let retry_strategy = self.reconnect_strategy;
let timeout = self.timeout;
let transport = self.transport;
let f = async move {
let transport = match timeout {
Some(duration) => tokio::time::timeout(duration, self.connector.connect())
.await
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
.and_then(convert::identity)?,
None => self.connector.connect().await?,
};
let connection = Connection::client(transport, auth_handler).await?;
Ok(Client::spawn(connection, retry_strategy))
Ok(UntypedClient::spawn(connection, retry_strategy))
};
match timeout {
@ -126,4 +132,11 @@ where
None => f.await,
}
}
/// Establishes a connection with a remote server using the configured [`Transport`] and other
/// settings, returning a new [`Client`] instance once the connection is fully established and
/// authenticated.
pub async fn connect<T, U>(self) -> io::Result<Client<T, U>> {
Ok(self.connect_untyped().await?.into_typed_client())
}
}

@ -1,47 +1,31 @@
use crate::client::{Client, ClientBuilder, ReconnectStrategy};
use crate::common::{authentication::AuthHandler, TcpTransport};
use serde::{de::DeserializeOwned, Serialize};
use tokio::{io, net::ToSocketAddrs, time::Duration};
use super::Connector;
use crate::common::TcpTransport;
use async_trait::async_trait;
use std::io;
use tokio::net::ToSocketAddrs;
/// Builder for a client that will connect over TCP
pub struct TcpClientBuilder<T>(ClientBuilder<T, ()>);
impl<T> TcpClientBuilder<T> {
pub fn auth_handler<A: AuthHandler>(self, auth_handler: A) -> TcpClientBuilder<A> {
TcpClientBuilder(self.0.auth_handler(auth_handler))
}
pub fn reconnect_strategy(self, reconnect_strategy: ReconnectStrategy) -> TcpClientBuilder<T> {
TcpClientBuilder(self.0.reconnect_strategy(reconnect_strategy))
}
pub fn timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
Self(self.0.timeout(timeout))
}
/// Implementation of [`Connector`] to support connecting via TCP.
pub struct TcpConnector<T> {
addr: T,
}
impl TcpClientBuilder<()> {
pub fn new() -> Self {
Self(ClientBuilder::new())
impl<T> TcpConnector<T> {
pub fn new(addr: T) -> Self {
Self { addr }
}
}
impl Default for TcpClientBuilder<()> {
fn default() -> Self {
Self::new()
impl<T> From<T> for TcpConnector<T> {
fn from(addr: T) -> Self {
Self::new(addr)
}
}
impl<A: AuthHandler + Send> TcpClientBuilder<A> {
pub async fn connect<T, U>(self, addr: impl ToSocketAddrs) -> io::Result<Client<T, U>>
where
T: Send + Sync + Serialize + 'static,
U: Send + Sync + DeserializeOwned + 'static,
{
self.0
.try_transport(TcpTransport::connect(addr))
.await?
.connect()
.await
#[async_trait]
impl<T: ToSocketAddrs + Send> Connector for TcpConnector<T> {
type Transport = TcpTransport;
async fn connect(self) -> io::Result<Self::Transport> {
TcpTransport::connect(self.addr).await
}
}

@ -1,51 +1,30 @@
use crate::client::{Client, ClientBuilder, ReconnectStrategy};
use crate::common::{authentication::AuthHandler, UnixSocketTransport};
use serde::{de::DeserializeOwned, Serialize};
use std::path::Path;
use tokio::{io, time::Duration};
use super::Connector;
use crate::common::UnixSocketTransport;
use async_trait::async_trait;
use std::{io, path::PathBuf};
/// Builder for a client that will connect over a Unix socket
pub struct UnixSocketClientBuilder<T>(ClientBuilder<T, ()>);
impl<T> UnixSocketClientBuilder<T> {
pub fn auth_handler<A: AuthHandler>(self, auth_handler: A) -> UnixSocketClientBuilder<A> {
UnixSocketClientBuilder(self.0.auth_handler(auth_handler))
}
pub fn reconnect_strategy(
self,
reconnect_strategy: ReconnectStrategy,
) -> UnixSocketClientBuilder<T> {
UnixSocketClientBuilder(self.0.reconnect_strategy(reconnect_strategy))
}
pub fn timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
Self(self.0.timeout(timeout))
}
/// Implementation of [`Connector`] to support connecting via a Unix socket.
pub struct UnixSocketConnector {
path: PathBuf,
}
impl UnixSocketClientBuilder<()> {
pub fn new() -> Self {
Self(ClientBuilder::new())
impl UnixSocketConnector {
pub fn new(path: impl Into<PathBuf>) -> Self {
Self { path: path.into() }
}
}
impl Default for UnixSocketClientBuilder<()> {
fn default() -> Self {
Self::new()
impl<T: Into<PathBuf>> From<T> for UnixSocketConnector {
fn from(path: T) -> Self {
Self::new(path)
}
}
impl<A: AuthHandler + Send> UnixSocketClientBuilder<A> {
pub async fn connect<T, U>(self, path: impl AsRef<Path> + Send) -> io::Result<Client<T, U>>
where
T: Send + Sync + Serialize + 'static,
U: Send + Sync + DeserializeOwned + 'static,
{
self.0
.try_transport(UnixSocketTransport::connect(path.as_ref()))
.await?
.connect()
.await
#[async_trait]
impl Connector for UnixSocketConnector {
type Transport = UnixSocketTransport;
async fn connect(self) -> io::Result<Self::Transport> {
UnixSocketTransport::connect(self.path).await
}
}

@ -1,76 +1,44 @@
use crate::client::{Client, ClientBuilder, ReconnectStrategy};
use crate::common::{authentication::AuthHandler, WindowsPipeTransport};
use serde::{de::DeserializeOwned, Serialize};
use super::Connector;
use crate::common::WindowsPipeTransport;
use async_trait::async_trait;
use std::ffi::{OsStr, OsString};
use tokio::{io, time::Duration};
use std::io;
/// Builder for a client that will connect over a Windows pipe
pub struct WindowsPipeClientBuilder<T> {
inner: ClientBuilder<T, ()>,
local: bool,
/// Implementation of [`Connector`] to support connecting via a Windows named pipe.
pub struct WindowsPipeConnector {
addr: OsString,
pub(crate) local: bool,
}
impl<T> WindowsPipeClientBuilder<T> {
pub fn auth_handler<A: AuthHandler>(self, auth_handler: A) -> WindowsPipeClientBuilder<A> {
WindowsPipeClientBuilder {
inner: self.inner.auth_handler(auth_handler),
local: self.local,
}
impl WindowsPipeConnector {
/// Creates a new connector for a non-local pipe using the given `addr`.
pub fn new(addr: impl Into<OsString>) -> Self {
Self { addr: addr.into(), local: false }
}
pub fn reconnect_strategy(self, reconnect_strategy: ReconnectStrategy) -> WindowsPipeClientBuilder<T> {
WindowsPipeClientBuilder(self.0.reconnect_strategy(reconnect_strategy))
}
/// If true, will connect to a server listening on a Windows pipe at the specified address
/// via `\\.\pipe\{name}`; otherwise, will connect using the address verbatim.
pub fn local(self, local: bool) -> Self {
Self {
inner: self.inner,
local,
}
}
pub fn timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
Self {
inner: self.inner.timeout(timeout),
local: self.local,
}
/// Creates a new connector for a local pipe using the given `name`.
pub fn local(name: impl Into<OsString>) -> Self {
Self { addr: name.into(), local: true }
}
}
impl WindowsPipeClientBuilder<()> {
pub fn new() -> Self {
Self {
inner: ClientBuilder::new(),
local: false,
}
impl<T: Into<OsString>> From<T> for WindowsPipeConnector {
fn from(addr: T) -> Self {
Self::new(path)
}
}
impl Default for WindowsPipeClientBuilder<()> {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Connector for WindowsPipeConnector {
type Transport = WindowsPipeTransport;
impl<A: AuthHandler + Send> WindowsPipeClientBuilder<A> {
pub async fn connect<T, U>(self, addr: impl AsRef<OsStr> + Send) -> io::Result<Client<T, U>>
where
T: Send + Sync + Serialize + 'static,
U: Send + Sync + DeserializeOwned + 'static,
{
let local = self.local;
self.0
.try_transport(if local {
async fn connect(self) -> io::Result<Self::Transport> {
WindowsPipeTransport::connect(if local {
let mut full_addr = OsString::from(r"\\.\pipe\");
full_addr.push(addr.as_ref());
WindowsPipeTransport::connect(full_addr)
} else {
WindowsPipeTransport::connect(addr.as_ref())
})
.await?
.connect()
.await
}).await
}
}

@ -58,31 +58,26 @@ where
self.inner
}
/// Assigns a default mailbox for any response received that does not match another mailbox.
pub async fn assign_default_mailbox(&self, buffer: usize) -> io::Result<Mailbox<Response<U>>> {
Ok(map_to_typed_mailbox(
self.inner.assign_default_mailbox(buffer).await?,
))
}
/// Removes the default mailbox used for unmatched responses such that any response without a
/// matching mailbox will be dropped.
pub async fn remove_default_mailbox(&self) -> io::Result<()> {
self.inner.remove_default_mailbox().await
}
/// Sends a request and returns a mailbox that can receive one or more responses, failing if
/// unable to send a request or if the session's receiving line to the remote server has
/// already been severed
pub async fn mail(&mut self, req: impl Into<Request<T>>) -> io::Result<Mailbox<Response<U>>> {
Ok(self
.inner
.mail(req.into().to_untyped_request()?)
.await?
.map_opt(|res| match res.to_typed_response() {
Ok(res) => Some(res),
Err(x) => {
if log::log_enabled!(Level::Trace) {
trace!(
"Invalid response payload: {}",
String::from_utf8_lossy(&res.payload)
);
}
error!(
"Unable to parse response payload into {}: {x}",
std::any::type_name::<U>()
);
None
}
}))
Ok(map_to_typed_mailbox(
self.inner.mail(req.into().to_untyped_request()?).await?,
))
}
/// Sends a request and returns a mailbox, timing out after duration has passed
@ -150,6 +145,28 @@ where
}
}
fn map_to_typed_mailbox<T: Send + DeserializeOwned + 'static>(
mailbox: Mailbox<UntypedResponse<'static>>,
) -> Mailbox<Response<T>> {
mailbox.map_opt(|res| match res.to_typed_response() {
Ok(res) => Some(res),
Err(x) => {
if log::log_enabled!(Level::Trace) {
trace!(
"Invalid response payload: {}",
String::from_utf8_lossy(&res.payload)
);
}
error!(
"Unable to parse response payload into {}: {x}",
std::any::type_name::<T>()
);
None
}
})
}
/// Represents a sender of requests tied to a session, holding onto a weak reference of
/// mailboxes to relay responses, meaning that once the [`Client`] is closed or dropped,
/// any sent request will no longer be able to receive responses.
@ -192,6 +209,32 @@ impl UntypedChannel {
}
}
/// Assigns a default mailbox for any response received that does not match another mailbox.
pub async fn assign_default_mailbox(
&self,
buffer: usize,
) -> io::Result<Mailbox<UntypedResponse<'static>>> {
match Weak::upgrade(&self.post_office) {
Some(post_office) => Ok(post_office.assign_default_mailbox(buffer).await),
None => Err(io::Error::new(
io::ErrorKind::NotConnected,
"Channel's post office is no longer available",
)),
}
}
/// Removes the default mailbox used for unmatched responses such that any response without a
/// matching mailbox will be dropped.
pub async fn remove_default_mailbox(&self) -> io::Result<()> {
match Weak::upgrade(&self.post_office) {
Some(post_office) => Ok(post_office.remove_default_mailbox().await),
None => Err(io::Error::new(
io::ErrorKind::NotConnected,
"Channel's post office is no longer available",
)),
}
}
/// Sends a request and returns a mailbox that can receive one or more responses, failing if
/// unable to send a request or if the session's receiving line to the remote server has
/// already been severed

@ -7,13 +7,14 @@ use std::{
};
use tokio::{
io,
sync::{mpsc, Mutex},
sync::{mpsc, Mutex, RwLock},
time,
};
#[derive(Clone, Debug)]
pub struct PostOffice<T> {
mailboxes: Arc<Mutex<HashMap<Id, mpsc::Sender<T>>>>,
default_box: Arc<RwLock<Option<mpsc::Sender<T>>>>,
}
impl<T> Default for PostOffice<T>
@ -52,7 +53,10 @@ where
}
});
Self { mailboxes }
Self {
mailboxes,
default_box: Arc::new(RwLock::new(None)),
}
}
/// Creates a new mailbox using the given id and buffer size for maximum values that
@ -79,11 +83,37 @@ where
}
success
} else if let Some(tx) = self.default_box.read().await.as_ref() {
tx.send(value).await.is_ok()
} else {
false
}
}
/// Creates a new default mailbox that will be used whenever no mailbox is found to deliver
/// mail. This will replace any existing default mailbox.
pub async fn assign_default_mailbox(&self, buffer: usize) -> Mailbox<T> {
let (tx, rx) = mpsc::channel(buffer);
*self.default_box.write().await = Some(tx);
Mailbox {
id: "".to_string(),
rx: Box::new(rx),
}
}
/// Removes the default mailbox such that any mail without a matching mailbox will be dropped
/// instead of being delivered to a default mailbox.
pub async fn remove_default_mailbox(&self) {
*self.default_box.write().await = None;
}
/// Returns true if the post office is using a default mailbox for all mail that does not map
/// to another mailbox.
pub async fn has_default_mailbox(&self) -> bool {
self.default_box.read().await.is_some()
}
/// Cancels delivery to the mailbox with the specified `id`.
pub async fn cancel(&self, id: &Id) {
self.mailboxes.lock().await.remove(id);

@ -1,5 +1,5 @@
use crate::{
client::{Client, ClientBuilder},
client::Client,
common::{
authentication::{
msg::{Authentication, AuthenticationResponse},

@ -1,5 +1,5 @@
use crate::{
client::{Client, ReconnectStrategy},
client::{Client, ReconnectStrategy, UntypedClient},
common::{authentication::AuthHandler, Connection, ConnectionId, InmemoryTransport},
manager::data::{ManagerRequest, ManagerResponse},
};
@ -46,6 +46,24 @@ impl RawChannel {
let connection = Connection::client(self.transport, handler).await?;
Ok(Client::spawn(connection, ReconnectStrategy::Fail))
}
/// Consumes this channel, returning an untyped client wrapping the transport.
///
/// ### Note
///
/// This will perform necessary handshakes and authentication (via `handler`) with the server.
///
/// Because the underlying transport maps to the same, singular connection with the manager
/// of servers, the reconnect strategy is set up to fail immediately as the actual reconnect
/// logic is handled by the primary client connection with the manager, not the connection
/// with a proxied server.
pub async fn spawn_untyped_client(
self,
handler: impl AuthHandler + Send,
) -> io::Result<UntypedClient> {
let connection = Connection::client(self.transport, handler).await?;
Ok(UntypedClient::spawn(connection, ReconnectStrategy::Fail))
}
}
impl Deref for RawChannel {

@ -124,7 +124,7 @@ impl ManagerServer {
.await?
};
let connection = ManagerConnection::new(destination, options, client);
let connection = ManagerConnection::spawn(destination, options, client).await?;
let id = connection.id;
self.connections.write().await.insert(id, connection);
Ok(id)
@ -312,7 +312,8 @@ impl ServerHandler for ManagerServer {
#[cfg(test)]
mod tests {
use super::*;
use crate::common::{FramedTransport, Transport};
use crate::client::{ReconnectStrategy, UntypedClient};
use crate::common::FramedTransport;
use crate::server::ServerReply;
use crate::{boxed_connect_handler, boxed_launch_handler};
use tokio::sync::mpsc;
@ -328,9 +329,9 @@ mod tests {
}
}
/// Create a framed transport that is detached such that reads and writes will fail
fn detached_framed_transport() -> FramedTransport<Box<dyn Transport>> {
FramedTransport::pair(1).0.into_boxed()
/// Create an untyped client that is detached such that reads and writes will fail
fn detached_untyped_client() -> UntypedClient {
UntypedClient::spawn_inmemory(FramedTransport::pair(1).0, ReconnectStrategy::Fail)
}
/// Create a new server and authenticator
@ -452,7 +453,7 @@ mod tests {
async fn connect_should_return_id_of_new_connection_on_success() {
let mut config = test_config();
let handler = boxed_connect_handler!(|_a, _b, _c| { Ok(detached_framed_transport()) });
let handler = boxed_connect_handler!(|_a, _b, _c| { Ok(detached_untyped_client()) });
config
.connect_handlers
@ -485,11 +486,13 @@ mod tests {
async fn info_should_return_information_about_established_connection() {
let (server, _) = setup(test_config());
let connection = ManagerConnection::new(
let connection = ManagerConnection::spawn(
"scheme://host".parse().unwrap(),
"key=value".parse().unwrap(),
detached_framed_transport(),
);
detached_untyped_client(),
)
.await
.unwrap();
let id = connection.id;
server.connections.write().await.insert(id, connection);
@ -516,19 +519,23 @@ mod tests {
async fn list_should_return_a_list_of_established_connections() {
let (server, _) = setup(test_config());
let connection = ManagerConnection::new(
let connection = ManagerConnection::spawn(
"scheme://host".parse().unwrap(),
"key=value".parse().unwrap(),
detached_framed_transport(),
);
detached_untyped_client(),
)
.await
.unwrap();
let id_1 = connection.id;
server.connections.write().await.insert(id_1, connection);
let connection = ManagerConnection::new(
let connection = ManagerConnection::spawn(
"other://host2".parse().unwrap(),
"key=value".parse().unwrap(),
detached_framed_transport(),
);
detached_untyped_client(),
)
.await
.unwrap();
let id_2 = connection.id;
server.connections.write().await.insert(id_2, connection);
@ -555,11 +562,13 @@ mod tests {
async fn kill_should_terminate_established_connection_and_remove_it_from_the_list() {
let (server, _) = setup(test_config());
let connection = ManagerConnection::new(
let connection = ManagerConnection::spawn(
"scheme://host".parse().unwrap(),
"key=value".parse().unwrap(),
detached_framed_transport(),
);
detached_untyped_client(),
)
.await
.unwrap();
let id = connection.id;
server.connections.write().await.insert(id, connection);

@ -1,14 +1,11 @@
use crate::{
client::UntypedClient,
common::{
ConnectionId, Destination, FramedTransport, Interest, Map, Transport, UntypedRequest,
UntypedResponse,
},
client::{Mailbox, UntypedClient},
common::{ConnectionId, Destination, Map, UntypedRequest, UntypedResponse},
manager::data::{ManagerChannelId, ManagerResponse},
server::ServerReply,
};
use log::*;
use std::{collections::HashMap, io, time::Duration};
use std::{collections::HashMap, io};
use tokio::{sync::mpsc, task::JoinHandle};
/// Represents a connection a distant manager has with some distant-compatible server
@ -17,8 +14,10 @@ pub struct ManagerConnection {
pub destination: Destination,
pub options: Map,
tx: mpsc::UnboundedSender<Action>,
transport_task: JoinHandle<()>,
action_task: JoinHandle<()>,
request_task: JoinHandle<()>,
response_task: JoinHandle<()>,
}
#[derive(Clone)]
@ -34,10 +33,14 @@ impl ManagerChannel {
pub fn send(&self, data: Vec<u8>) -> io::Result<()> {
let channel_id = self.channel_id;
let req = UntypedRequest::from_slice(&data)
.map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))?
.into_owned();
self.tx
.send(Action::Write {
id: channel_id,
data,
req,
})
.map_err(|x| {
io::Error::new(
@ -61,28 +64,32 @@ impl ManagerChannel {
}
impl ManagerConnection {
pub fn new(destination: Destination, options: Map, client: UntypedClient) -> Self {
pub async fn spawn(
spawn: Destination,
options: Map,
client: UntypedClient,
) -> io::Result<Self> {
let connection_id = rand::random();
let (tx, rx) = mpsc::unbounded_channel();
let (outgoing_tx, outgoing_rx) = mpsc::unbounded_channel();
let transport_task = tokio::spawn(transport_task(
let (request_tx, request_rx) = mpsc::unbounded_channel();
let action_task = tokio::spawn(action_task(connection_id, rx, request_tx));
let response_task = tokio::spawn(response_task(
connection_id,
client,
outgoing_rx,
client.assign_default_mailbox(100).await?,
tx.clone(),
Duration::from_millis(50),
));
let action_task = tokio::spawn(action_task(connection_id, rx, outgoing_tx));
let request_task = tokio::spawn(request_task(connection_id, client, request_rx));
Self {
Ok(Self {
id: connection_id,
destination,
destination: spawn,
options,
tx,
transport_task,
action_task,
}
request_task,
response_task,
})
}
pub fn open_channel(&self, reply: ServerReply<ManagerResponse>) -> io::Result<ManagerChannel> {
@ -107,8 +114,9 @@ impl ManagerConnection {
impl Drop for ManagerConnection {
fn drop(&mut self) {
self.transport_task.abort();
self.action_task.abort();
self.request_task.abort();
self.response_task.abort();
}
}
@ -123,91 +131,37 @@ enum Action {
},
Read {
data: Vec<u8>,
res: UntypedResponse<'static>,
},
Write {
id: ManagerChannelId,
data: Vec<u8>,
req: UntypedRequest<'static>,
},
}
/// Internal task to read and write from a [`Transport`].
///
/// * `id` - the id of the connection.
/// * `transport` - the fully-authenticated transport.
/// * `rx` - used to receive outgoing data to send through the connection.
/// * `tx` - used to send new [`Action`]s to process.
async fn transport_task(
/// Internal task to process outgoing [`UntypedRequest`]s.
async fn request_task(
id: ConnectionId,
mut client: UntypedClient,
mut rx: mpsc::UnboundedReceiver<Vec<u8>>,
tx: mpsc::UnboundedSender<Action>,
sleep_duration: Duration,
mut rx: mpsc::UnboundedReceiver<UntypedRequest<'static>>,
) {
loop {
let ready = match client.ready(Interest::READABLE | Interest::WRITABLE).await {
Ok(ready) => ready,
Err(x) => {
error!("[Conn {id}] Querying ready status failed: {x}");
break;
}
};
// Keep track of whether we read or wrote anything
let mut read_blocked = !ready.is_readable();
let mut write_blocked = !ready.is_writable();
// If transport is readable, attempt to read a frame and forward it to our action task
if ready.is_readable() {
match client.try_read_frame() {
Ok(Some(frame)) => {
if let Err(x) = tx.send(Action::Read {
data: frame.into_item().into_owned(),
}) {
error!("[Conn {id}] Failed to forward frame: {x}");
}
}
Ok(None) => {
debug!("[Conn {id}] Connection closed");
break;
}
Err(x) if x.kind() == io::ErrorKind::WouldBlock => read_blocked = true,
Err(x) => {
error!("[Conn {id}] {x}");
}
}
}
// If transport is writable, check if we have something to write
if ready.is_writable() {
if let Ok(data) = rx.try_recv() {
match client.try_write_frame(data) {
Ok(()) => (),
Err(x) if x.kind() == io::ErrorKind::WouldBlock => write_blocked = true,
Err(x) => error!("[Conn {id}] Send failed: {x}"),
}
} else {
// In the case of flushing, there are two scenarios in which we want to
// mark no write occurring:
//
// 1. When flush did not write any bytes, which can happen when the buffer
// is empty
// 2. When the call to write bytes blocks
match client.try_flush() {
Ok(0) => write_blocked = true,
Ok(_) => (),
Err(x) if x.kind() == io::ErrorKind::WouldBlock => write_blocked = true,
Err(x) => {
error!("[Conn {id}] {x}");
}
}
}
while let Some(req) = rx.recv().await {
if let Err(x) = client.fire(req).await {
error!("[Conn {id}] Failed to send request: {x}");
}
}
}
// If we did not read or write anything, sleep a bit to offload CPU usage
if read_blocked && write_blocked {
tokio::time::sleep(sleep_duration).await;
/// Internal task to process incoming [`UntypedResponse`]s.
async fn response_task(
id: ConnectionId,
mut mailbox: Mailbox<UntypedResponse<'static>>,
tx: mpsc::UnboundedSender<Action>,
) {
while let Some(res) = mailbox.next().await {
if let Err(x) = tx.send(Action::Read { res }) {
error!("[Conn {id}] Failed to forward received response: {x}");
}
}
}
@ -216,11 +170,11 @@ async fn transport_task(
///
/// * `id` - the id of the connection.
/// * `rx` - used to receive new [`Action`]s to process.
/// * `tx` - used to send outgoing data through the connection.
/// * `tx` - used to send outgoing requests through the connection.
async fn action_task(
id: ConnectionId,
mut rx: mpsc::UnboundedReceiver<Action>,
tx: mpsc::UnboundedSender<Vec<u8>>,
tx: mpsc::UnboundedSender<UntypedRequest<'static>>,
) {
let mut registered = HashMap::new();
@ -232,22 +186,13 @@ async fn action_task(
Action::Unregister { id } => {
registered.remove(&id);
}
Action::Read { data } => {
// Partially parse data into a request so we can modify the origin id
let mut response = match UntypedResponse::from_slice(&data) {
Ok(response) => response,
Err(x) => {
error!("[Conn {id}] Failed to parse response during read: {x}");
continue;
}
};
Action::Read { mut res } => {
// Split {channel id}_{request id} back into pieces and
// update the origin id to match the request id only
let channel_id = match response.origin_id.split_once('_') {
let channel_id = match res.origin_id.split_once('_') {
Some((cid_str, oid_str)) => {
if let Ok(cid) = cid_str.parse::<ManagerChannelId>() {
response.set_origin_id(oid_str.to_string());
res.set_origin_id(oid_str.to_string());
cid
} else {
continue;
@ -259,28 +204,19 @@ async fn action_task(
if let Some(reply) = registered.get(&channel_id) {
let response = ManagerResponse::Channel {
id: channel_id,
data: response.to_bytes(),
data: res.to_bytes(),
};
if let Err(x) = reply.send(response).await {
error!("[Conn {id}] {x}");
}
}
}
Action::Write { id, data } => {
// Partially parse data into a request so we can modify the id
let mut request = match UntypedRequest::from_slice(&data) {
Ok(request) => request,
Err(x) => {
error!("[Conn {id}] Failed to parse request during write: {x}");
continue;
}
};
Action::Write { id, mut req } => {
// Combine channel id with request id so we can properly forward
// the response containing this in the origin id
request.set_id(format!("{id}_{}", request.id));
req.set_id(format!("{id}_{}", req.id));
if let Err(x) = tx.send(request.to_bytes()) {
if let Err(x) = tx.send(req) {
error!("[Conn {id}] {x}");
}
}

@ -86,11 +86,12 @@ mod tests {
.await
.expect("Failed to start TCP server");
let mut client: Client<String, String> = Client::tcp()
.auth_handler(DummyAuthHandler)
.connect(SocketAddr::from((server.ip_addr(), server.port())))
.await
.expect("Client failed to connect");
let mut client: Client<String, String> =
Client::tcp(SocketAddr::from((server.ip_addr(), server.port())))
.auth_handler(DummyAuthHandler)
.connect()
.await
.expect("Client failed to connect");
let response = client
.send(Request::new("hello".to_string()))

@ -93,9 +93,9 @@ mod tests {
.await
.expect("Failed to start Unix socket server");
let mut client: Client<String, String> = Client::unix_socket()
let mut client: Client<String, String> = Client::unix_socket(server.path())
.auth_handler(DummyAuthHandler)
.connect(server.path())
.connect()
.await
.expect("Client failed to connect");

@ -2,7 +2,7 @@ use async_trait::async_trait;
use distant_net::boxed_connect_handler;
use distant_net::client::{Client, ReconnectStrategy};
use distant_net::common::authentication::{DummyAuthHandler, Verifier};
use distant_net::common::{Destination, FramedTransport, InmemoryTransport, Map, OneshotListener};
use distant_net::common::{Destination, InmemoryTransport, Map, OneshotListener};
use distant_net::manager::{Config, ManagerClient, ManagerServer};
use distant_net::server::{Server, ServerCtx, ServerHandler};
use std::io;
@ -42,23 +42,23 @@ async fn should_be_able_to_establish_a_single_connection_and_communicate() {
let client = Client::build()
.auth_handler(DummyAuthHandler)
.reconnect_strategy(ReconnectStrategy::Fail)
.transport(t1)
.connect()
.connector(t1)
.connect_untyped()
.await?;
Ok(client)
}),
);
let manager_ref = ManagerServer::new(Config::default())
let _manager_ref = ManagerServer::new(Config::default())
.verifier(Verifier::none())
.start(OneshotListener::from_value(t2))
.expect("Failed to start manager server");
let mut client = ManagerClient::build()
let mut client: ManagerClient = Client::build()
.auth_handler(DummyAuthHandler)
.reconnect_strategy(ReconnectStrategy::Fail)
.transport(t1)
.connector(t1)
.connect()
.await
.expect("Failed to connect to manager");
@ -91,14 +91,18 @@ async fn should_be_able_to_establish_a_single_connection_and_communicate() {
assert_eq!(info.options, "key=value".parse::<Map>().unwrap());
// Create a new channel and request some data
let mut channel = client
let mut channel_client: Client<u8, u8> = client
.open_raw_channel(id)
.await
.expect("Failed to open channel");
let _ = channel
.system_info()
.expect("Failed to open channel")
.spawn_client(DummyAuthHandler)
.await
.expect("Failed to get system information");
.expect("Failed to spawn client for channel");
let res = channel_client
.send(123u8)
.await
.expect("Failed to send request to server");
assert_eq!(res.payload, 123u8, "Invalid response payload");
// Test killing a connection
client.kill(id).await.expect("Failed to kill connection");

Loading…
Cancel
Save