mirror of https://github.com/chipsenkbeil/distant
Rewrite to support custom authentication, handshakes for encryption/compression, and reconnecting (#146)
parent
7d1b3ba6f0
commit
4798b67dfe
@ -1,6 +1,6 @@
|
||||
[profile.ci]
|
||||
fail-fast = false
|
||||
retries = 2
|
||||
retries = 4
|
||||
slow-timeout = { period = "60s", terminate-after = 3 }
|
||||
status-level = "fail"
|
||||
final-status-level = "fail"
|
||||
|
@ -1,783 +0,0 @@
|
||||
use super::data::{
|
||||
ConnectionId, ConnectionInfo, ConnectionList, Destination, ManagerCapabilities, ManagerRequest,
|
||||
ManagerResponse,
|
||||
};
|
||||
use crate::{
|
||||
DistantChannel, DistantClient, DistantMsg, DistantRequestData, DistantResponseData, Map,
|
||||
};
|
||||
use distant_net::{
|
||||
router, Auth, AuthServer, Client, IntoSplit, MpscTransport, OneshotListener, Request, Response,
|
||||
ServerExt, ServerRef, UntypedTransportRead, UntypedTransportWrite,
|
||||
};
|
||||
use log::*;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
io,
|
||||
ops::{Deref, DerefMut},
|
||||
};
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
mod config;
|
||||
pub use config::*;
|
||||
|
||||
mod ext;
|
||||
pub use ext::*;
|
||||
|
||||
router!(DistantManagerClientRouter {
|
||||
auth_transport: Request<Auth> => Response<Auth>,
|
||||
manager_transport: Response<ManagerResponse> => Request<ManagerRequest>,
|
||||
});
|
||||
|
||||
/// Represents a client that can connect to a remote distant manager
|
||||
pub struct DistantManagerClient {
|
||||
auth: Box<dyn ServerRef>,
|
||||
client: Client<ManagerRequest, ManagerResponse>,
|
||||
distant_clients: HashMap<ConnectionId, ClientHandle>,
|
||||
}
|
||||
|
||||
impl Drop for DistantManagerClient {
|
||||
fn drop(&mut self) {
|
||||
self.auth.abort();
|
||||
self.client.abort();
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a raw channel between a manager client and some remote server
|
||||
pub struct RawDistantChannel {
|
||||
pub transport: MpscTransport<
|
||||
Request<DistantMsg<DistantRequestData>>,
|
||||
Response<DistantMsg<DistantResponseData>>,
|
||||
>,
|
||||
forward_task: JoinHandle<()>,
|
||||
mailbox_task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl RawDistantChannel {
|
||||
pub fn abort(&self) {
|
||||
self.forward_task.abort();
|
||||
self.mailbox_task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for RawDistantChannel {
|
||||
type Target = MpscTransport<
|
||||
Request<DistantMsg<DistantRequestData>>,
|
||||
Response<DistantMsg<DistantResponseData>>,
|
||||
>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.transport
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for RawDistantChannel {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.transport
|
||||
}
|
||||
}
|
||||
|
||||
struct ClientHandle {
|
||||
client: DistantClient,
|
||||
forward_task: JoinHandle<()>,
|
||||
mailbox_task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl Drop for ClientHandle {
|
||||
fn drop(&mut self) {
|
||||
self.forward_task.abort();
|
||||
self.mailbox_task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
impl DistantManagerClient {
|
||||
/// Initializes a client using the provided [`UntypedTransport`]
|
||||
pub fn new<T>(config: DistantManagerClientConfig, transport: T) -> io::Result<Self>
|
||||
where
|
||||
T: IntoSplit + 'static,
|
||||
T::Read: UntypedTransportRead + 'static,
|
||||
T::Write: UntypedTransportWrite + 'static,
|
||||
{
|
||||
let DistantManagerClientRouter {
|
||||
auth_transport,
|
||||
manager_transport,
|
||||
..
|
||||
} = DistantManagerClientRouter::new(transport);
|
||||
|
||||
// Initialize our client with manager request/response transport
|
||||
let (writer, reader) = manager_transport.into_split();
|
||||
let client = Client::new(writer, reader)?;
|
||||
|
||||
// Initialize our auth handler with auth/auth transport
|
||||
let auth = AuthServer {
|
||||
on_challenge: config.on_challenge,
|
||||
on_verify: config.on_verify,
|
||||
on_info: config.on_info,
|
||||
on_error: config.on_error,
|
||||
}
|
||||
.start(OneshotListener::from_value(auth_transport.into_split()))?;
|
||||
|
||||
Ok(Self {
|
||||
auth,
|
||||
client,
|
||||
distant_clients: HashMap::new(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Request that the manager launches a new server at the given `destination`
|
||||
/// with `options` being passed for destination-specific details, returning the new
|
||||
/// `destination` of the spawned server to connect to
|
||||
pub async fn launch(
|
||||
&mut self,
|
||||
destination: impl Into<Destination>,
|
||||
options: impl Into<Map>,
|
||||
) -> io::Result<Destination> {
|
||||
let destination = Box::new(destination.into());
|
||||
let options = options.into();
|
||||
trace!("launch({}, {})", destination, options);
|
||||
|
||||
let res = self
|
||||
.client
|
||||
.send(ManagerRequest::Launch {
|
||||
destination,
|
||||
options,
|
||||
})
|
||||
.await?;
|
||||
match res.payload {
|
||||
ManagerResponse::Launched { destination } => Ok(destination),
|
||||
ManagerResponse::Error(x) => Err(x.into()),
|
||||
x => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("Got unexpected response: {:?}", x),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Request that the manager establishes a new connection at the given `destination`
|
||||
/// with `options` being passed for destination-specific details
|
||||
pub async fn connect(
|
||||
&mut self,
|
||||
destination: impl Into<Destination>,
|
||||
options: impl Into<Map>,
|
||||
) -> io::Result<ConnectionId> {
|
||||
let destination = Box::new(destination.into());
|
||||
let options = options.into();
|
||||
trace!("connect({}, {})", destination, options);
|
||||
|
||||
let res = self
|
||||
.client
|
||||
.send(ManagerRequest::Connect {
|
||||
destination,
|
||||
options,
|
||||
})
|
||||
.await?;
|
||||
match res.payload {
|
||||
ManagerResponse::Connected { id } => Ok(id),
|
||||
ManagerResponse::Error(x) => Err(x.into()),
|
||||
x => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("Got unexpected response: {:?}", x),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Establishes a channel with the server represented by the `connection_id`,
|
||||
/// returning a [`DistantChannel`] acting as the connection
|
||||
///
|
||||
/// ### Note
|
||||
///
|
||||
/// Multiple calls to open a channel against the same connection will result in
|
||||
/// clones of the same [`DistantChannel`] rather than establishing a duplicate
|
||||
/// remote connection to the same server
|
||||
pub async fn open_channel(
|
||||
&mut self,
|
||||
connection_id: ConnectionId,
|
||||
) -> io::Result<DistantChannel> {
|
||||
trace!("open_channel({})", connection_id);
|
||||
if let Some(handle) = self.distant_clients.get(&connection_id) {
|
||||
Ok(handle.client.clone_channel())
|
||||
} else {
|
||||
let RawDistantChannel {
|
||||
transport,
|
||||
forward_task,
|
||||
mailbox_task,
|
||||
} = self.open_raw_channel(connection_id).await?;
|
||||
let (writer, reader) = transport.into_split();
|
||||
let client = DistantClient::new(writer, reader)?;
|
||||
let channel = client.clone_channel();
|
||||
self.distant_clients.insert(
|
||||
connection_id,
|
||||
ClientHandle {
|
||||
client,
|
||||
forward_task,
|
||||
mailbox_task,
|
||||
},
|
||||
);
|
||||
Ok(channel)
|
||||
}
|
||||
}
|
||||
|
||||
/// Establishes a channel with the server represented by the `connection_id`,
|
||||
/// returning a [`Transport`] acting as the connection
|
||||
///
|
||||
/// ### Note
|
||||
///
|
||||
/// Multiple calls to open a channel against the same connection will result in establishing a
|
||||
/// duplicate remote connections to the same server, so take care when using this method
|
||||
pub async fn open_raw_channel(
|
||||
&mut self,
|
||||
connection_id: ConnectionId,
|
||||
) -> io::Result<RawDistantChannel> {
|
||||
trace!("open_raw_channel({})", connection_id);
|
||||
let mut mailbox = self
|
||||
.client
|
||||
.mail(ManagerRequest::OpenChannel { id: connection_id })
|
||||
.await?;
|
||||
|
||||
// Wait for the first response, which should be channel confirmation
|
||||
let channel_id = match mailbox.next().await {
|
||||
Some(response) => match response.payload {
|
||||
ManagerResponse::ChannelOpened { id } => Ok(id),
|
||||
ManagerResponse::Error(x) => Err(x.into()),
|
||||
x => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("Got unexpected response: {:?}", x),
|
||||
)),
|
||||
},
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::ConnectionAborted,
|
||||
"open_channel mailbox aborted",
|
||||
)),
|
||||
}?;
|
||||
|
||||
// Spawn reader and writer tasks to forward requests and replies
|
||||
// using our opened channel
|
||||
let (t1, t2) = MpscTransport::pair(1);
|
||||
let (mut writer, mut reader) = t1.into_split();
|
||||
let mailbox_task = tokio::spawn(async move {
|
||||
use distant_net::TypedAsyncWrite;
|
||||
while let Some(response) = mailbox.next().await {
|
||||
match response.payload {
|
||||
ManagerResponse::Channel { response, .. } => {
|
||||
if let Err(x) = writer.write(response).await {
|
||||
error!("[Conn {}] {}", connection_id, x);
|
||||
}
|
||||
}
|
||||
ManagerResponse::ChannelClosed { .. } => break,
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let mut manager_channel = self.client.clone_channel();
|
||||
let forward_task = tokio::spawn(async move {
|
||||
use distant_net::TypedAsyncRead;
|
||||
loop {
|
||||
match reader.read().await {
|
||||
Ok(Some(request)) => {
|
||||
// NOTE: In this situation, we do not expect a response to this
|
||||
// request (even if the server sends something back)
|
||||
if let Err(x) = manager_channel
|
||||
.fire(ManagerRequest::Channel {
|
||||
id: channel_id,
|
||||
request,
|
||||
})
|
||||
.await
|
||||
{
|
||||
error!("[Conn {}] {}", connection_id, x);
|
||||
}
|
||||
}
|
||||
Ok(None) => break,
|
||||
Err(x) => {
|
||||
error!("[Conn {}] {}", connection_id, x);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(RawDistantChannel {
|
||||
transport: t2,
|
||||
forward_task,
|
||||
mailbox_task,
|
||||
})
|
||||
}
|
||||
|
||||
/// Retrieves a list of supported capabilities
|
||||
pub async fn capabilities(&mut self) -> io::Result<ManagerCapabilities> {
|
||||
trace!("capabilities()");
|
||||
let res = self.client.send(ManagerRequest::Capabilities).await?;
|
||||
match res.payload {
|
||||
ManagerResponse::Capabilities { supported } => Ok(supported),
|
||||
ManagerResponse::Error(x) => Err(x.into()),
|
||||
x => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("Got unexpected response: {:?}", x),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieves information about a specific connection
|
||||
pub async fn info(&mut self, id: ConnectionId) -> io::Result<ConnectionInfo> {
|
||||
trace!("info({})", id);
|
||||
let res = self.client.send(ManagerRequest::Info { id }).await?;
|
||||
match res.payload {
|
||||
ManagerResponse::Info(info) => Ok(info),
|
||||
ManagerResponse::Error(x) => Err(x.into()),
|
||||
x => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("Got unexpected response: {:?}", x),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Kills the specified connection
|
||||
pub async fn kill(&mut self, id: ConnectionId) -> io::Result<()> {
|
||||
trace!("kill({})", id);
|
||||
let res = self.client.send(ManagerRequest::Kill { id }).await?;
|
||||
match res.payload {
|
||||
ManagerResponse::Killed => Ok(()),
|
||||
ManagerResponse::Error(x) => Err(x.into()),
|
||||
x => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("Got unexpected response: {:?}", x),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieves a list of active connections
|
||||
pub async fn list(&mut self) -> io::Result<ConnectionList> {
|
||||
trace!("list()");
|
||||
let res = self.client.send(ManagerRequest::List).await?;
|
||||
match res.payload {
|
||||
ManagerResponse::List(list) => Ok(list),
|
||||
ManagerResponse::Error(x) => Err(x.into()),
|
||||
x => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("Got unexpected response: {:?}", x),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Requests that the manager shuts down
|
||||
pub async fn shutdown(&mut self) -> io::Result<()> {
|
||||
trace!("shutdown()");
|
||||
let res = self.client.send(ManagerRequest::Shutdown).await?;
|
||||
match res.payload {
|
||||
ManagerResponse::Shutdown => Ok(()),
|
||||
ManagerResponse::Error(x) => Err(x.into()),
|
||||
x => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("Got unexpected response: {:?}", x),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::data::{Error, ErrorKind};
|
||||
use distant_net::{
|
||||
FramedTransport, InmemoryTransport, PlainCodec, UntypedTransportRead, UntypedTransportWrite,
|
||||
};
|
||||
|
||||
fn setup() -> (
|
||||
DistantManagerClient,
|
||||
FramedTransport<InmemoryTransport, PlainCodec>,
|
||||
) {
|
||||
let (t1, t2) = FramedTransport::pair(100);
|
||||
let client =
|
||||
DistantManagerClient::new(DistantManagerClientConfig::with_empty_prompts(), t1)
|
||||
.unwrap();
|
||||
(client, t2)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn test_error() -> Error {
|
||||
Error {
|
||||
kind: ErrorKind::Interrupted,
|
||||
description: "test error".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn test_io_error() -> io::Error {
|
||||
test_error().into()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_should_report_error_if_receives_error_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
ManagerResponse::Error(test_error()),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let err = client
|
||||
.connect(
|
||||
"scheme://host".parse::<Destination>().unwrap(),
|
||||
"key=value".parse::<Map>().unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert_eq!(err.kind(), test_io_error().kind());
|
||||
assert_eq!(err.to_string(), test_io_error().to_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_should_report_error_if_receives_unexpected_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(request.id, ManagerResponse::Shutdown))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let err = client
|
||||
.connect(
|
||||
"scheme://host".parse::<Destination>().unwrap(),
|
||||
"key=value".parse::<Map>().unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_should_return_id_from_successful_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
let expected_id = 999;
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
ManagerResponse::Connected { id: expected_id },
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let id = client
|
||||
.connect(
|
||||
"scheme://host".parse::<Destination>().unwrap(),
|
||||
"key=value".parse::<Map>().unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(id, expected_id);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn info_should_report_error_if_receives_error_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
ManagerResponse::Error(test_error()),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let err = client.info(123).await.unwrap_err();
|
||||
assert_eq!(err.kind(), test_io_error().kind());
|
||||
assert_eq!(err.to_string(), test_io_error().to_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn info_should_report_error_if_receives_unexpected_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(request.id, ManagerResponse::Shutdown))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let err = client.info(123).await.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn info_should_return_connection_info_from_successful_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
let info = ConnectionInfo {
|
||||
id: 123,
|
||||
destination: "scheme://host".parse::<Destination>().unwrap(),
|
||||
options: "key=value".parse::<Map>().unwrap(),
|
||||
};
|
||||
|
||||
transport
|
||||
.write(Response::new(request.id, ManagerResponse::Info(info)))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let info = client.info(123).await.unwrap();
|
||||
assert_eq!(info.id, 123);
|
||||
assert_eq!(
|
||||
info.destination,
|
||||
"scheme://host".parse::<Destination>().unwrap()
|
||||
);
|
||||
assert_eq!(info.options, "key=value".parse::<Map>().unwrap());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_should_report_error_if_receives_error_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
ManagerResponse::Error(test_error()),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let err = client.list().await.unwrap_err();
|
||||
assert_eq!(err.kind(), test_io_error().kind());
|
||||
assert_eq!(err.to_string(), test_io_error().to_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_should_report_error_if_receives_unexpected_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(request.id, ManagerResponse::Shutdown))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let err = client.list().await.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_should_return_connection_list_from_successful_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
let mut list = ConnectionList::new();
|
||||
list.insert(123, "scheme://host".parse::<Destination>().unwrap());
|
||||
|
||||
transport
|
||||
.write(Response::new(request.id, ManagerResponse::List(list)))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let list = client.list().await.unwrap();
|
||||
assert_eq!(list.len(), 1);
|
||||
assert_eq!(
|
||||
list.get(&123).expect("Connection list missing item"),
|
||||
&"scheme://host".parse::<Destination>().unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn kill_should_report_error_if_receives_error_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
ManagerResponse::Error(test_error()),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let err = client.kill(123).await.unwrap_err();
|
||||
assert_eq!(err.kind(), test_io_error().kind());
|
||||
assert_eq!(err.to_string(), test_io_error().to_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn kill_should_report_error_if_receives_unexpected_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(request.id, ManagerResponse::Shutdown))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let err = client.kill(123).await.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn kill_should_return_success_from_successful_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(request.id, ManagerResponse::Killed))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
client.kill(123).await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shutdown_should_report_error_if_receives_error_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
ManagerResponse::Connected { id: 0 },
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let err = client.shutdown().await.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shutdown_should_report_error_if_receives_unexpected_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
ManagerResponse::Error(test_error()),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let err = client.shutdown().await.unwrap_err();
|
||||
assert_eq!(err.kind(), test_io_error().kind());
|
||||
assert_eq!(err.to_string(), test_io_error().to_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shutdown_should_return_success_from_successful_response() {
|
||||
let (mut client, mut transport) = setup();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let request = transport
|
||||
.read::<Request<ManagerRequest>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.write(Response::new(request.id, ManagerResponse::Shutdown))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
client.shutdown().await.unwrap();
|
||||
}
|
||||
}
|
@ -1,85 +0,0 @@
|
||||
use distant_net::{AuthChallengeFn, AuthErrorFn, AuthInfoFn, AuthVerifyFn, AuthVerifyKind};
|
||||
use log::*;
|
||||
use std::io;
|
||||
|
||||
/// Configuration to use when creating a new [`DistantManagerClient`](super::DistantManagerClient)
|
||||
pub struct DistantManagerClientConfig {
|
||||
pub on_challenge: Box<AuthChallengeFn>,
|
||||
pub on_verify: Box<AuthVerifyFn>,
|
||||
pub on_info: Box<AuthInfoFn>,
|
||||
pub on_error: Box<AuthErrorFn>,
|
||||
}
|
||||
|
||||
impl DistantManagerClientConfig {
|
||||
/// Creates a new config with prompts that return empty strings
|
||||
pub fn with_empty_prompts() -> Self {
|
||||
Self::with_prompts(|_| Ok("".to_string()), |_| Ok("".to_string()))
|
||||
}
|
||||
|
||||
/// Creates a new config with two prompts
|
||||
///
|
||||
/// * `password_prompt` - used for prompting for a secret, and should not display what is typed
|
||||
/// * `text_prompt` - used for general text, and is okay to display what is typed
|
||||
pub fn with_prompts<PP, PT>(password_prompt: PP, text_prompt: PT) -> Self
|
||||
where
|
||||
PP: Fn(&str) -> io::Result<String> + Send + Sync + 'static,
|
||||
PT: Fn(&str) -> io::Result<String> + Send + Sync + 'static,
|
||||
{
|
||||
Self {
|
||||
on_challenge: Box::new(move |questions, _extra| {
|
||||
trace!("[manager client] on_challenge({questions:?}, {_extra:?})");
|
||||
let mut answers = Vec::new();
|
||||
for question in questions.iter() {
|
||||
// Contains all prompt lines including same line
|
||||
let mut lines = question.text.split('\n').collect::<Vec<_>>();
|
||||
|
||||
// Line that is prompt on same line as answer
|
||||
let line = lines.pop().unwrap();
|
||||
|
||||
// Go ahead and display all other lines
|
||||
for line in lines.into_iter() {
|
||||
eprintln!("{}", line);
|
||||
}
|
||||
|
||||
// Get an answer from user input, or use a blank string as an answer
|
||||
// if we fail to get input from the user
|
||||
let answer = password_prompt(line).unwrap_or_default();
|
||||
|
||||
answers.push(answer);
|
||||
}
|
||||
answers
|
||||
}),
|
||||
on_verify: Box::new(move |kind, text| {
|
||||
trace!("[manager client] on_verify({kind}, {text})");
|
||||
match kind {
|
||||
AuthVerifyKind::Host => {
|
||||
eprintln!("{}", text);
|
||||
|
||||
match text_prompt("Enter [y/N]> ") {
|
||||
Ok(answer) => {
|
||||
trace!("Verify? Answer = '{answer}'");
|
||||
matches!(answer.trim(), "y" | "Y" | "yes" | "YES")
|
||||
}
|
||||
Err(x) => {
|
||||
error!("Failed verification: {x}");
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
x => {
|
||||
error!("Unsupported verify kind: {x}");
|
||||
false
|
||||
}
|
||||
}
|
||||
}),
|
||||
on_info: Box::new(|text| {
|
||||
trace!("[manager client] on_info({text})");
|
||||
println!("{}", text);
|
||||
}),
|
||||
on_error: Box::new(|kind, text| {
|
||||
trace!("[manager client] on_error({kind}, {text})");
|
||||
eprintln!("{}: {}", kind, text);
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
@ -1,14 +0,0 @@
|
||||
mod tcp;
|
||||
pub use tcp::*;
|
||||
|
||||
#[cfg(unix)]
|
||||
mod unix;
|
||||
|
||||
#[cfg(unix)]
|
||||
pub use unix::*;
|
||||
|
||||
#[cfg(windows)]
|
||||
mod windows;
|
||||
|
||||
#[cfg(windows)]
|
||||
pub use windows::*;
|
@ -1,50 +0,0 @@
|
||||
use crate::{DistantManagerClient, DistantManagerClientConfig};
|
||||
use async_trait::async_trait;
|
||||
use distant_net::{Codec, FramedTransport, TcpTransport};
|
||||
use std::{convert, net::SocketAddr};
|
||||
use tokio::{io, time::Duration};
|
||||
|
||||
#[async_trait]
|
||||
pub trait TcpDistantManagerClientExt {
|
||||
/// Connect to a remote TCP server using the provided information
|
||||
async fn connect<C>(
|
||||
config: DistantManagerClientConfig,
|
||||
addr: SocketAddr,
|
||||
codec: C,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
C: Codec + Send + 'static;
|
||||
|
||||
/// Connect to a remote TCP server, timing out after duration has passed
|
||||
async fn connect_timeout<C>(
|
||||
config: DistantManagerClientConfig,
|
||||
addr: SocketAddr,
|
||||
codec: C,
|
||||
duration: Duration,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
tokio::time::timeout(duration, Self::connect(config, addr, codec))
|
||||
.await
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
|
||||
.and_then(convert::identity)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TcpDistantManagerClientExt for DistantManagerClient {
|
||||
/// Connect to a remote TCP server using the provided information
|
||||
async fn connect<C>(
|
||||
config: DistantManagerClientConfig,
|
||||
addr: SocketAddr,
|
||||
codec: C,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
let transport = TcpTransport::connect(addr).await?;
|
||||
let transport = FramedTransport::new(transport, codec);
|
||||
Self::new(config, transport)
|
||||
}
|
||||
}
|
@ -1,54 +0,0 @@
|
||||
use crate::{DistantManagerClient, DistantManagerClientConfig};
|
||||
use async_trait::async_trait;
|
||||
use distant_net::{Codec, FramedTransport, UnixSocketTransport};
|
||||
use std::{convert, path::Path};
|
||||
use tokio::{io, time::Duration};
|
||||
|
||||
#[async_trait]
|
||||
pub trait UnixSocketDistantManagerClientExt {
|
||||
/// Connect to a proxy unix socket
|
||||
async fn connect<P, C>(
|
||||
config: DistantManagerClientConfig,
|
||||
path: P,
|
||||
codec: C,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
P: AsRef<Path> + Send,
|
||||
C: Codec + Send + 'static;
|
||||
|
||||
/// Connect to a proxy unix socket, timing out after duration has passed
|
||||
async fn connect_timeout<P, C>(
|
||||
config: DistantManagerClientConfig,
|
||||
path: P,
|
||||
codec: C,
|
||||
duration: Duration,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
P: AsRef<Path> + Send,
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
tokio::time::timeout(duration, Self::connect(config, path, codec))
|
||||
.await
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
|
||||
.and_then(convert::identity)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl UnixSocketDistantManagerClientExt for DistantManagerClient {
|
||||
/// Connect to a proxy unix socket
|
||||
async fn connect<P, C>(
|
||||
config: DistantManagerClientConfig,
|
||||
path: P,
|
||||
codec: C,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
P: AsRef<Path> + Send,
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
let p = path.as_ref();
|
||||
let transport = UnixSocketTransport::connect(p).await?;
|
||||
let transport = FramedTransport::new(transport, codec);
|
||||
Ok(DistantManagerClient::new(config, transport)?)
|
||||
}
|
||||
}
|
@ -1,91 +0,0 @@
|
||||
use crate::{DistantManagerClient, DistantManagerClientConfig};
|
||||
use async_trait::async_trait;
|
||||
use distant_net::{Codec, FramedTransport, WindowsPipeTransport};
|
||||
use std::{
|
||||
convert,
|
||||
ffi::{OsStr, OsString},
|
||||
};
|
||||
use tokio::{io, time::Duration};
|
||||
|
||||
#[async_trait]
|
||||
pub trait WindowsPipeDistantManagerClientExt {
|
||||
/// Connect to a server listening on a Windows pipe at the specified address
|
||||
/// using the given codec
|
||||
async fn connect<A, C>(
|
||||
config: DistantManagerClientConfig,
|
||||
addr: A,
|
||||
codec: C,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
A: AsRef<OsStr> + Send,
|
||||
C: Codec + Send + 'static;
|
||||
|
||||
/// Connect to a server listening on a Windows pipe at the specified address
|
||||
/// via `\\.\pipe\{name}` using the given codec
|
||||
async fn connect_local<N, C>(
|
||||
config: DistantManagerClientConfig,
|
||||
name: N,
|
||||
codec: C,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
N: AsRef<OsStr> + Send,
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
let mut addr = OsString::from(r"\\.\pipe\");
|
||||
addr.push(name.as_ref());
|
||||
Self::connect(config, addr, codec).await
|
||||
}
|
||||
|
||||
/// Connect to a server listening on a Windows pipe at the specified address
|
||||
/// using the given codec, timing out after duration has passed
|
||||
async fn connect_timeout<A, C>(
|
||||
config: DistantManagerClientConfig,
|
||||
addr: A,
|
||||
codec: C,
|
||||
duration: Duration,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
A: AsRef<OsStr> + Send,
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
tokio::time::timeout(duration, Self::connect(config, addr, codec))
|
||||
.await
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
|
||||
.and_then(convert::identity)
|
||||
}
|
||||
|
||||
/// Connect to a server listening on a Windows pipe at the specified address
|
||||
/// via `\\.\pipe\{name}` using the given codec, timing out after duration has passed
|
||||
async fn connect_local_timeout<N, C>(
|
||||
config: DistantManagerClientConfig,
|
||||
name: N,
|
||||
codec: C,
|
||||
duration: Duration,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
N: AsRef<OsStr> + Send,
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
let mut addr = OsString::from(r"\\.\pipe\");
|
||||
addr.push(name.as_ref());
|
||||
Self::connect_timeout(config, addr, codec, duration).await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl WindowsPipeDistantManagerClientExt for DistantManagerClient {
|
||||
async fn connect<A, C>(
|
||||
config: DistantManagerClientConfig,
|
||||
addr: A,
|
||||
codec: C,
|
||||
) -> io::Result<DistantManagerClient>
|
||||
where
|
||||
A: AsRef<OsStr> + Send,
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
let a = addr.as_ref();
|
||||
let transport = WindowsPipeTransport::connect(a).await?;
|
||||
let transport = FramedTransport::new(transport, codec);
|
||||
Ok(DistantManagerClient::new(config, transport)?)
|
||||
}
|
||||
}
|
@ -1,5 +0,0 @@
|
||||
/// Id associated with an active connection
|
||||
pub type ConnectionId = u64;
|
||||
|
||||
/// Id associated with an open channel
|
||||
pub type ChannelId = u64;
|
@ -1,719 +0,0 @@
|
||||
use crate::{
|
||||
ChannelId, ConnectionId, ConnectionInfo, ConnectionList, Destination, ManagerCapabilities,
|
||||
ManagerRequest, ManagerResponse, Map,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use distant_net::{
|
||||
router, Auth, AuthClient, Client, IntoSplit, Listener, MpscListener, Request, Response, Server,
|
||||
ServerCtx, ServerExt, UntypedTransportRead, UntypedTransportWrite,
|
||||
};
|
||||
use log::*;
|
||||
use std::{collections::HashMap, io, sync::Arc};
|
||||
use tokio::{
|
||||
sync::{mpsc, Mutex, RwLock},
|
||||
task::JoinHandle,
|
||||
};
|
||||
|
||||
mod config;
|
||||
pub use config::*;
|
||||
|
||||
mod connection;
|
||||
pub use connection::*;
|
||||
|
||||
mod ext;
|
||||
pub use ext::*;
|
||||
|
||||
mod handler;
|
||||
pub use handler::*;
|
||||
|
||||
mod r#ref;
|
||||
pub use r#ref::*;
|
||||
|
||||
router!(DistantManagerRouter {
|
||||
auth_transport: Response<Auth> => Request<Auth>,
|
||||
manager_transport: Request<ManagerRequest> => Response<ManagerResponse>,
|
||||
});
|
||||
|
||||
/// Represents a manager of multiple distant server connections
|
||||
pub struct DistantManager {
|
||||
/// Receives authentication clients to feed into local data of server
|
||||
auth_client_rx: Mutex<mpsc::Receiver<AuthClient>>,
|
||||
|
||||
/// Configuration settings for the server
|
||||
config: DistantManagerConfig,
|
||||
|
||||
/// Mapping of connection id -> connection
|
||||
connections: RwLock<HashMap<ConnectionId, DistantManagerConnection>>,
|
||||
|
||||
/// Handlers for launch requests
|
||||
launch_handlers: Arc<RwLock<HashMap<String, BoxedLaunchHandler>>>,
|
||||
|
||||
/// Handlers for connect requests
|
||||
connect_handlers: Arc<RwLock<HashMap<String, BoxedConnectHandler>>>,
|
||||
|
||||
/// Primary task of server
|
||||
task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl DistantManager {
|
||||
/// Initializes a new instance of [`DistantManagerServer`] using the provided [`UntypedTransport`]
|
||||
pub fn start<L, T>(
|
||||
mut config: DistantManagerConfig,
|
||||
mut listener: L,
|
||||
) -> io::Result<DistantManagerRef>
|
||||
where
|
||||
L: Listener<Output = T> + 'static,
|
||||
T: IntoSplit + Send + 'static,
|
||||
T::Read: UntypedTransportRead + 'static,
|
||||
T::Write: UntypedTransportWrite + 'static,
|
||||
{
|
||||
let (conn_tx, mpsc_listener) = MpscListener::channel(config.connection_buffer_size);
|
||||
let (auth_client_tx, auth_client_rx) = mpsc::channel(1);
|
||||
|
||||
// Spawn task that uses our input listener to get both auth and manager events,
|
||||
// forwarding manager events to the internal mpsc listener
|
||||
let task = tokio::spawn(async move {
|
||||
while let Ok(transport) = listener.accept().await {
|
||||
let DistantManagerRouter {
|
||||
auth_transport,
|
||||
manager_transport,
|
||||
..
|
||||
} = DistantManagerRouter::new(transport);
|
||||
|
||||
let (writer, reader) = auth_transport.into_split();
|
||||
let client = match Client::new(writer, reader) {
|
||||
Ok(client) => client,
|
||||
Err(x) => {
|
||||
error!("Creating auth client failed: {}", x);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let auth_client = AuthClient::from(client);
|
||||
|
||||
// Forward auth client for new connection in server
|
||||
if auth_client_tx.send(auth_client).await.is_err() {
|
||||
break;
|
||||
}
|
||||
|
||||
// Forward connected and routed transport to server
|
||||
if conn_tx.send(manager_transport.into_split()).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let launch_handlers = Arc::new(RwLock::new(config.launch_handlers.drain().collect()));
|
||||
let weak_launch_handlers = Arc::downgrade(&launch_handlers);
|
||||
let connect_handlers = Arc::new(RwLock::new(config.connect_handlers.drain().collect()));
|
||||
let weak_connect_handlers = Arc::downgrade(&connect_handlers);
|
||||
let server_ref = Self {
|
||||
auth_client_rx: Mutex::new(auth_client_rx),
|
||||
config,
|
||||
launch_handlers,
|
||||
connect_handlers,
|
||||
connections: RwLock::new(HashMap::new()),
|
||||
task,
|
||||
}
|
||||
.start(mpsc_listener)?;
|
||||
|
||||
Ok(DistantManagerRef {
|
||||
launch_handlers: weak_launch_handlers,
|
||||
connect_handlers: weak_connect_handlers,
|
||||
inner: server_ref,
|
||||
})
|
||||
}
|
||||
|
||||
/// Launches a new server at the specified `destination` using the given `options` information
|
||||
/// and authentication client (if needed) to retrieve additional information needed to
|
||||
/// enter the destination prior to starting the server, returning the destination of the
|
||||
/// launched server
|
||||
async fn launch(
|
||||
&self,
|
||||
destination: Destination,
|
||||
options: Map,
|
||||
auth: Option<&mut AuthClient>,
|
||||
) -> io::Result<Destination> {
|
||||
let auth = auth.ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Authentication client not initialized",
|
||||
)
|
||||
})?;
|
||||
|
||||
let scheme = match destination.scheme.as_deref() {
|
||||
Some(scheme) => {
|
||||
trace!("Using scheme {}", scheme);
|
||||
scheme
|
||||
}
|
||||
None => {
|
||||
trace!(
|
||||
"Using fallback scheme of {}",
|
||||
self.config.launch_fallback_scheme.as_str()
|
||||
);
|
||||
self.config.launch_fallback_scheme.as_str()
|
||||
}
|
||||
}
|
||||
.to_lowercase();
|
||||
|
||||
let credentials = {
|
||||
let lock = self.launch_handlers.read().await;
|
||||
let handler = lock.get(&scheme).ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
format!("No launch handler registered for {}", scheme),
|
||||
)
|
||||
})?;
|
||||
handler.launch(&destination, &options, auth).await?
|
||||
};
|
||||
|
||||
Ok(credentials)
|
||||
}
|
||||
|
||||
/// Connects to a new server at the specified `destination` using the given `options` information
|
||||
/// and authentication client (if needed) to retrieve additional information needed to
|
||||
/// establish the connection to the server
|
||||
async fn connect(
|
||||
&self,
|
||||
destination: Destination,
|
||||
options: Map,
|
||||
auth: Option<&mut AuthClient>,
|
||||
) -> io::Result<ConnectionId> {
|
||||
let auth = auth.ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Authentication client not initialized",
|
||||
)
|
||||
})?;
|
||||
|
||||
let scheme = match destination.scheme.as_deref() {
|
||||
Some(scheme) => {
|
||||
trace!("Using scheme {}", scheme);
|
||||
scheme
|
||||
}
|
||||
None => {
|
||||
trace!(
|
||||
"Using fallback scheme of {}",
|
||||
self.config.connect_fallback_scheme.as_str()
|
||||
);
|
||||
self.config.connect_fallback_scheme.as_str()
|
||||
}
|
||||
}
|
||||
.to_lowercase();
|
||||
|
||||
let (writer, reader) = {
|
||||
let lock = self.connect_handlers.read().await;
|
||||
let handler = lock.get(&scheme).ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
format!("No connect handler registered for {}", scheme),
|
||||
)
|
||||
})?;
|
||||
handler.connect(&destination, &options, auth).await?
|
||||
};
|
||||
|
||||
let connection = DistantManagerConnection::new(destination, options, writer, reader);
|
||||
let id = connection.id;
|
||||
self.connections.write().await.insert(id, connection);
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
/// Retrieves the list of supported capabilities for this manager
|
||||
async fn capabilities(&self) -> io::Result<ManagerCapabilities> {
|
||||
Ok(ManagerCapabilities::all())
|
||||
}
|
||||
|
||||
/// Retrieves information about the connection to the server with the specified `id`
|
||||
async fn info(&self, id: ConnectionId) -> io::Result<ConnectionInfo> {
|
||||
match self.connections.read().await.get(&id) {
|
||||
Some(connection) => Ok(ConnectionInfo {
|
||||
id: connection.id,
|
||||
destination: connection.destination.clone(),
|
||||
options: connection.options.clone(),
|
||||
}),
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::NotConnected,
|
||||
"No connection found",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieves a list of connections to servers
|
||||
async fn list(&self) -> io::Result<ConnectionList> {
|
||||
Ok(ConnectionList(
|
||||
self.connections
|
||||
.read()
|
||||
.await
|
||||
.values()
|
||||
.map(|conn| (conn.id, conn.destination.clone()))
|
||||
.collect(),
|
||||
))
|
||||
}
|
||||
|
||||
/// Kills the connection to the server with the specified `id`
|
||||
async fn kill(&self, id: ConnectionId) -> io::Result<()> {
|
||||
match self.connections.write().await.remove(&id) {
|
||||
Some(_) => Ok(()),
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::NotConnected,
|
||||
"No connection found",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct DistantManagerServerConnection {
|
||||
/// Authentication client that manager can use when establishing a new connection
|
||||
/// and needing to get authentication details from the client to move forward
|
||||
auth_client: Option<Mutex<AuthClient>>,
|
||||
|
||||
/// Holds on to open channels feeding data back from a server to some connected client,
|
||||
/// enabling us to cancel the tasks on demand
|
||||
channels: RwLock<HashMap<ChannelId, DistantManagerChannel>>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Server for DistantManager {
|
||||
type Request = ManagerRequest;
|
||||
type Response = ManagerResponse;
|
||||
type LocalData = DistantManagerServerConnection;
|
||||
|
||||
async fn on_accept(&self, local_data: &mut Self::LocalData) {
|
||||
local_data.auth_client = self
|
||||
.auth_client_rx
|
||||
.lock()
|
||||
.await
|
||||
.recv()
|
||||
.await
|
||||
.map(Mutex::new);
|
||||
|
||||
// Enable jit handshake
|
||||
if let Some(auth_client) = local_data.auth_client.as_ref() {
|
||||
auth_client.lock().await.set_jit_handshake(true);
|
||||
}
|
||||
}
|
||||
|
||||
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) {
|
||||
let ServerCtx {
|
||||
connection_id,
|
||||
request,
|
||||
reply,
|
||||
local_data,
|
||||
} = ctx;
|
||||
|
||||
let response = match request.payload {
|
||||
ManagerRequest::Capabilities {} => match self.capabilities().await {
|
||||
Ok(supported) => ManagerResponse::Capabilities { supported },
|
||||
Err(x) => ManagerResponse::Error(x.into()),
|
||||
},
|
||||
ManagerRequest::Launch {
|
||||
destination,
|
||||
options,
|
||||
} => {
|
||||
let mut auth = match local_data.auth_client.as_ref() {
|
||||
Some(client) => Some(client.lock().await),
|
||||
None => None,
|
||||
};
|
||||
|
||||
match self
|
||||
.launch(*destination, options, auth.as_deref_mut())
|
||||
.await
|
||||
{
|
||||
Ok(destination) => ManagerResponse::Launched { destination },
|
||||
Err(x) => ManagerResponse::Error(x.into()),
|
||||
}
|
||||
}
|
||||
ManagerRequest::Connect {
|
||||
destination,
|
||||
options,
|
||||
} => {
|
||||
let mut auth = match local_data.auth_client.as_ref() {
|
||||
Some(client) => Some(client.lock().await),
|
||||
None => None,
|
||||
};
|
||||
|
||||
match self
|
||||
.connect(*destination, options, auth.as_deref_mut())
|
||||
.await
|
||||
{
|
||||
Ok(id) => ManagerResponse::Connected { id },
|
||||
Err(x) => ManagerResponse::Error(x.into()),
|
||||
}
|
||||
}
|
||||
ManagerRequest::OpenChannel { id } => match self.connections.read().await.get(&id) {
|
||||
Some(connection) => match connection.open_channel(reply.clone()).await {
|
||||
Ok(channel) => {
|
||||
let id = channel.id();
|
||||
local_data.channels.write().await.insert(id, channel);
|
||||
ManagerResponse::ChannelOpened { id }
|
||||
}
|
||||
Err(x) => ManagerResponse::Error(x.into()),
|
||||
},
|
||||
None => ManagerResponse::Error(
|
||||
io::Error::new(io::ErrorKind::NotConnected, "Connection does not exist").into(),
|
||||
),
|
||||
},
|
||||
ManagerRequest::Channel { id, request } => {
|
||||
match local_data.channels.read().await.get(&id) {
|
||||
// TODO: For now, we are NOT sending back a response to acknowledge
|
||||
// a successful channel send. We could do this in order for
|
||||
// the client to listen for a complete send, but is it worth it?
|
||||
Some(channel) => match channel.send(request).await {
|
||||
Ok(_) => return,
|
||||
Err(x) => ManagerResponse::Error(x.into()),
|
||||
},
|
||||
None => ManagerResponse::Error(
|
||||
io::Error::new(
|
||||
io::ErrorKind::NotConnected,
|
||||
"Channel is not open or does not exist",
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
}
|
||||
}
|
||||
ManagerRequest::CloseChannel { id } => {
|
||||
match local_data.channels.write().await.remove(&id) {
|
||||
Some(channel) => match channel.close().await {
|
||||
Ok(_) => ManagerResponse::ChannelClosed { id },
|
||||
Err(x) => ManagerResponse::Error(x.into()),
|
||||
},
|
||||
None => ManagerResponse::Error(
|
||||
io::Error::new(
|
||||
io::ErrorKind::NotConnected,
|
||||
"Channel is not open or does not exist",
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
}
|
||||
}
|
||||
ManagerRequest::Info { id } => match self.info(id).await {
|
||||
Ok(info) => ManagerResponse::Info(info),
|
||||
Err(x) => ManagerResponse::Error(x.into()),
|
||||
},
|
||||
ManagerRequest::List => match self.list().await {
|
||||
Ok(list) => ManagerResponse::List(list),
|
||||
Err(x) => ManagerResponse::Error(x.into()),
|
||||
},
|
||||
ManagerRequest::Kill { id } => match self.kill(id).await {
|
||||
Ok(()) => ManagerResponse::Killed,
|
||||
Err(x) => ManagerResponse::Error(x.into()),
|
||||
},
|
||||
ManagerRequest::Shutdown => {
|
||||
if let Err(x) = reply.send(ManagerResponse::Shutdown).await {
|
||||
error!("[Conn {}] {}", connection_id, x);
|
||||
}
|
||||
|
||||
// Clear out handler state in order to trigger drops
|
||||
self.launch_handlers.write().await.clear();
|
||||
self.connect_handlers.write().await.clear();
|
||||
|
||||
// Shutdown the primary server task
|
||||
self.task.abort();
|
||||
|
||||
// TODO: Perform a graceful shutdown instead of this?
|
||||
// Review https://tokio.rs/tokio/topics/shutdown
|
||||
std::process::exit(0);
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(x) = reply.send(response).await {
|
||||
error!("[Conn {}] {}", connection_id, x);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use distant_net::{
|
||||
AuthClient, FramedTransport, HeapAuthServer, InmemoryTransport, IntoSplit, MappedListener,
|
||||
OneshotListener, PlainCodec, ServerExt, ServerRef,
|
||||
};
|
||||
|
||||
/// Create a new server, bypassing the start loop
|
||||
fn setup() -> DistantManager {
|
||||
let (_, rx) = mpsc::channel(1);
|
||||
DistantManager {
|
||||
auth_client_rx: Mutex::new(rx),
|
||||
config: Default::default(),
|
||||
connections: RwLock::new(HashMap::new()),
|
||||
launch_handlers: Arc::new(RwLock::new(HashMap::new())),
|
||||
connect_handlers: Arc::new(RwLock::new(HashMap::new())),
|
||||
task: tokio::spawn(async move {}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a connected [`AuthClient`] with a launched auth server that blindly responds
|
||||
fn auth_client_server() -> (AuthClient, Box<dyn ServerRef>) {
|
||||
let (t1, t2) = FramedTransport::pair(1);
|
||||
let client = AuthClient::from(Client::from_framed_transport(t1).unwrap());
|
||||
|
||||
// Create a server that does nothing, but will support
|
||||
let server = HeapAuthServer {
|
||||
on_challenge: Box::new(|_, _| Vec::new()),
|
||||
on_verify: Box::new(|_, _| false),
|
||||
on_info: Box::new(|_| ()),
|
||||
on_error: Box::new(|_, _| ()),
|
||||
}
|
||||
.start(MappedListener::new(OneshotListener::from_value(t2), |t| {
|
||||
t.into_split()
|
||||
}))
|
||||
.unwrap();
|
||||
|
||||
(client, server)
|
||||
}
|
||||
|
||||
fn dummy_distant_writer_reader() -> (BoxedDistantWriter, BoxedDistantReader) {
|
||||
setup_distant_writer_reader().0
|
||||
}
|
||||
|
||||
/// Creates a writer & reader with a connected transport
|
||||
fn setup_distant_writer_reader() -> (
|
||||
(BoxedDistantWriter, BoxedDistantReader),
|
||||
FramedTransport<InmemoryTransport, PlainCodec>,
|
||||
) {
|
||||
let (t1, t2) = FramedTransport::pair(1);
|
||||
let (writer, reader) = t1.into_split();
|
||||
((Box::new(writer), Box::new(reader)), t2)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn launch_should_fail_if_destination_scheme_is_unsupported() {
|
||||
let server = setup();
|
||||
|
||||
let destination = "scheme://host".parse::<Destination>().unwrap();
|
||||
let options = "".parse::<Map>().unwrap();
|
||||
let (mut auth, _auth_server) = auth_client_server();
|
||||
let err = server
|
||||
.launch(destination, options, Some(&mut auth))
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::InvalidInput, "{:?}", err);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn launch_should_fail_if_handler_tied_to_scheme_fails() {
|
||||
let server = setup();
|
||||
|
||||
let handler: Box<dyn LaunchHandler> = Box::new(|_: &_, _: &_, _: &mut _| async {
|
||||
Err(io::Error::new(io::ErrorKind::Other, "test failure"))
|
||||
});
|
||||
|
||||
server
|
||||
.launch_handlers
|
||||
.write()
|
||||
.await
|
||||
.insert("scheme".to_string(), handler);
|
||||
|
||||
let destination = "scheme://host".parse::<Destination>().unwrap();
|
||||
let options = "".parse::<Map>().unwrap();
|
||||
let (mut auth, _auth_server) = auth_client_server();
|
||||
let err = server
|
||||
.launch(destination, options, Some(&mut auth))
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::Other);
|
||||
assert_eq!(err.to_string(), "test failure");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn launch_should_return_new_destination_on_success() {
|
||||
let server = setup();
|
||||
|
||||
let handler: Box<dyn LaunchHandler> = {
|
||||
Box::new(|_: &_, _: &_, _: &mut _| async {
|
||||
Ok("scheme2://host2".parse::<Destination>().unwrap())
|
||||
})
|
||||
};
|
||||
|
||||
server
|
||||
.launch_handlers
|
||||
.write()
|
||||
.await
|
||||
.insert("scheme".to_string(), handler);
|
||||
|
||||
let destination = "scheme://host".parse::<Destination>().unwrap();
|
||||
let options = "key=value".parse::<Map>().unwrap();
|
||||
let (mut auth, _auth_server) = auth_client_server();
|
||||
let destination = server
|
||||
.launch(destination, options, Some(&mut auth))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
destination,
|
||||
"scheme2://host2".parse::<Destination>().unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_should_fail_if_destination_scheme_is_unsupported() {
|
||||
let server = setup();
|
||||
|
||||
let destination = "scheme://host".parse::<Destination>().unwrap();
|
||||
let options = "".parse::<Map>().unwrap();
|
||||
let (mut auth, _auth_server) = auth_client_server();
|
||||
let err = server
|
||||
.connect(destination, options, Some(&mut auth))
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::InvalidInput, "{:?}", err);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_should_fail_if_handler_tied_to_scheme_fails() {
|
||||
let server = setup();
|
||||
|
||||
let handler: Box<dyn ConnectHandler> = Box::new(|_: &_, _: &_, _: &mut _| async {
|
||||
Err(io::Error::new(io::ErrorKind::Other, "test failure"))
|
||||
});
|
||||
|
||||
server
|
||||
.connect_handlers
|
||||
.write()
|
||||
.await
|
||||
.insert("scheme".to_string(), handler);
|
||||
|
||||
let destination = "scheme://host".parse::<Destination>().unwrap();
|
||||
let options = "".parse::<Map>().unwrap();
|
||||
let (mut auth, _auth_server) = auth_client_server();
|
||||
let err = server
|
||||
.connect(destination, options, Some(&mut auth))
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::Other);
|
||||
assert_eq!(err.to_string(), "test failure");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_should_return_id_of_new_connection_on_success() {
|
||||
let server = setup();
|
||||
|
||||
let handler: Box<dyn ConnectHandler> =
|
||||
Box::new(|_: &_, _: &_, _: &mut _| async { Ok(dummy_distant_writer_reader()) });
|
||||
|
||||
server
|
||||
.connect_handlers
|
||||
.write()
|
||||
.await
|
||||
.insert("scheme".to_string(), handler);
|
||||
|
||||
let destination = "scheme://host".parse::<Destination>().unwrap();
|
||||
let options = "key=value".parse::<Map>().unwrap();
|
||||
let (mut auth, _auth_server) = auth_client_server();
|
||||
let id = server
|
||||
.connect(destination, options, Some(&mut auth))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let lock = server.connections.read().await;
|
||||
let connection = lock.get(&id).unwrap();
|
||||
assert_eq!(connection.id, id);
|
||||
assert_eq!(connection.destination, "scheme://host");
|
||||
assert_eq!(connection.options, "key=value".parse().unwrap());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn info_should_fail_if_no_connection_found_for_specified_id() {
|
||||
let server = setup();
|
||||
|
||||
let err = server.info(999).await.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::NotConnected, "{:?}", err);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn info_should_return_information_about_established_connection() {
|
||||
let server = setup();
|
||||
|
||||
let (writer, reader) = dummy_distant_writer_reader();
|
||||
let connection = DistantManagerConnection::new(
|
||||
"scheme://host".parse().unwrap(),
|
||||
"key=value".parse().unwrap(),
|
||||
writer,
|
||||
reader,
|
||||
);
|
||||
let id = connection.id;
|
||||
server.connections.write().await.insert(id, connection);
|
||||
|
||||
let info = server.info(id).await.unwrap();
|
||||
assert_eq!(
|
||||
info,
|
||||
ConnectionInfo {
|
||||
id,
|
||||
destination: "scheme://host".parse().unwrap(),
|
||||
options: "key=value".parse().unwrap(),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_should_return_empty_connection_list_if_no_established_connections() {
|
||||
let server = setup();
|
||||
|
||||
let list = server.list().await.unwrap();
|
||||
assert_eq!(list, ConnectionList(HashMap::new()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_should_return_a_list_of_established_connections() {
|
||||
let server = setup();
|
||||
|
||||
let (writer, reader) = dummy_distant_writer_reader();
|
||||
let connection = DistantManagerConnection::new(
|
||||
"scheme://host".parse().unwrap(),
|
||||
"key=value".parse().unwrap(),
|
||||
writer,
|
||||
reader,
|
||||
);
|
||||
let id_1 = connection.id;
|
||||
server.connections.write().await.insert(id_1, connection);
|
||||
|
||||
let (writer, reader) = dummy_distant_writer_reader();
|
||||
let connection = DistantManagerConnection::new(
|
||||
"other://host2".parse().unwrap(),
|
||||
"key=value".parse().unwrap(),
|
||||
writer,
|
||||
reader,
|
||||
);
|
||||
let id_2 = connection.id;
|
||||
server.connections.write().await.insert(id_2, connection);
|
||||
|
||||
let list = server.list().await.unwrap();
|
||||
assert_eq!(
|
||||
list.get(&id_1).unwrap(),
|
||||
&"scheme://host".parse::<Destination>().unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
list.get(&id_2).unwrap(),
|
||||
&"other://host2".parse::<Destination>().unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn kill_should_fail_if_no_connection_found_for_specified_id() {
|
||||
let server = setup();
|
||||
|
||||
let err = server.kill(999).await.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::NotConnected, "{:?}", err);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn kill_should_terminate_established_connection_and_remove_it_from_the_list() {
|
||||
let server = setup();
|
||||
|
||||
let (writer, reader) = dummy_distant_writer_reader();
|
||||
let connection = DistantManagerConnection::new(
|
||||
"scheme://host".parse().unwrap(),
|
||||
"key=value".parse().unwrap(),
|
||||
writer,
|
||||
reader,
|
||||
);
|
||||
let id = connection.id;
|
||||
server.connections.write().await.insert(id, connection);
|
||||
|
||||
server.kill(id).await.unwrap();
|
||||
|
||||
let lock = server.connections.read().await;
|
||||
assert!(!lock.contains_key(&id), "Connection still exists");
|
||||
}
|
||||
}
|
@ -1,202 +0,0 @@
|
||||
use crate::{
|
||||
data::Map,
|
||||
manager::{
|
||||
data::{ChannelId, ConnectionId, Destination},
|
||||
BoxedDistantReader, BoxedDistantWriter,
|
||||
},
|
||||
DistantMsg, DistantRequestData, DistantResponseData, ManagerResponse,
|
||||
};
|
||||
use distant_net::{Request, Response, ServerReply};
|
||||
use log::*;
|
||||
use std::{collections::HashMap, io};
|
||||
use tokio::{sync::mpsc, task::JoinHandle};
|
||||
|
||||
/// Represents a connection a distant manager has with some distant-compatible server
|
||||
pub struct DistantManagerConnection {
|
||||
pub id: ConnectionId,
|
||||
pub destination: Destination,
|
||||
pub options: Map,
|
||||
tx: mpsc::Sender<StateMachine>,
|
||||
reader_task: JoinHandle<()>,
|
||||
writer_task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct DistantManagerChannel {
|
||||
channel_id: ChannelId,
|
||||
tx: mpsc::Sender<StateMachine>,
|
||||
}
|
||||
|
||||
impl DistantManagerChannel {
|
||||
pub fn id(&self) -> ChannelId {
|
||||
self.channel_id
|
||||
}
|
||||
|
||||
pub async fn send(&self, request: Request<DistantMsg<DistantRequestData>>) -> io::Result<()> {
|
||||
let channel_id = self.channel_id;
|
||||
self.tx
|
||||
.send(StateMachine::Write {
|
||||
id: channel_id,
|
||||
request,
|
||||
})
|
||||
.await
|
||||
.map_err(|x| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::BrokenPipe,
|
||||
format!("channel {} send failed: {}", channel_id, x),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn close(&self) -> io::Result<()> {
|
||||
let channel_id = self.channel_id;
|
||||
self.tx
|
||||
.send(StateMachine::Unregister { id: channel_id })
|
||||
.await
|
||||
.map_err(|x| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::BrokenPipe,
|
||||
format!("channel {} close failed: {}", channel_id, x),
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
enum StateMachine {
|
||||
Register {
|
||||
id: ChannelId,
|
||||
reply: ServerReply<ManagerResponse>,
|
||||
},
|
||||
|
||||
Unregister {
|
||||
id: ChannelId,
|
||||
},
|
||||
|
||||
Read {
|
||||
response: Response<DistantMsg<DistantResponseData>>,
|
||||
},
|
||||
|
||||
Write {
|
||||
id: ChannelId,
|
||||
request: Request<DistantMsg<DistantRequestData>>,
|
||||
},
|
||||
}
|
||||
|
||||
impl DistantManagerConnection {
|
||||
pub fn new(
|
||||
destination: Destination,
|
||||
options: Map,
|
||||
mut writer: BoxedDistantWriter,
|
||||
mut reader: BoxedDistantReader,
|
||||
) -> Self {
|
||||
let connection_id = rand::random();
|
||||
let (tx, mut rx) = mpsc::channel(1);
|
||||
let reader_task = {
|
||||
let tx = tx.clone();
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
match reader.read().await {
|
||||
Ok(Some(response)) => {
|
||||
if tx.send(StateMachine::Read { response }).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(None) => break,
|
||||
Err(x) => {
|
||||
error!("[Conn {}] {}", connection_id, x);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
};
|
||||
let writer_task = tokio::spawn(async move {
|
||||
let mut registered = HashMap::new();
|
||||
while let Some(state_machine) = rx.recv().await {
|
||||
match state_machine {
|
||||
StateMachine::Register { id, reply } => {
|
||||
registered.insert(id, reply);
|
||||
}
|
||||
StateMachine::Unregister { id } => {
|
||||
registered.remove(&id);
|
||||
}
|
||||
StateMachine::Read { mut response } => {
|
||||
// Split {channel id}_{request id} back into pieces and
|
||||
// update the origin id to match the request id only
|
||||
let channel_id = match response.origin_id.split_once('_') {
|
||||
Some((cid_str, oid_str)) => {
|
||||
if let Ok(cid) = cid_str.parse::<ChannelId>() {
|
||||
response.origin_id = oid_str.to_string();
|
||||
cid
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
None => continue,
|
||||
};
|
||||
|
||||
if let Some(reply) = registered.get(&channel_id) {
|
||||
let response = ManagerResponse::Channel {
|
||||
id: channel_id,
|
||||
response,
|
||||
};
|
||||
if let Err(x) = reply.send(response).await {
|
||||
error!("[Conn {}] {}", connection_id, x);
|
||||
}
|
||||
}
|
||||
}
|
||||
StateMachine::Write { id, request } => {
|
||||
// Combine channel id with request id so we can properly forward
|
||||
// the response containing this in the origin id
|
||||
let request = Request {
|
||||
id: format!("{}_{}", id, request.id),
|
||||
payload: request.payload,
|
||||
};
|
||||
if let Err(x) = writer.write(request).await {
|
||||
error!("[Conn {}] {}", connection_id, x);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
id: connection_id,
|
||||
destination,
|
||||
options,
|
||||
tx,
|
||||
reader_task,
|
||||
writer_task,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn open_channel(
|
||||
&self,
|
||||
reply: ServerReply<ManagerResponse>,
|
||||
) -> io::Result<DistantManagerChannel> {
|
||||
let channel_id = rand::random();
|
||||
self.tx
|
||||
.send(StateMachine::Register {
|
||||
id: channel_id,
|
||||
reply,
|
||||
})
|
||||
.await
|
||||
.map_err(|x| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::BrokenPipe,
|
||||
format!("open_channel failed: {}", x),
|
||||
)
|
||||
})?;
|
||||
Ok(DistantManagerChannel {
|
||||
channel_id,
|
||||
tx: self.tx.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for DistantManagerConnection {
|
||||
fn drop(&mut self) {
|
||||
self.reader_task.abort();
|
||||
self.writer_task.abort();
|
||||
}
|
||||
}
|
@ -1,14 +0,0 @@
|
||||
mod tcp;
|
||||
pub use tcp::*;
|
||||
|
||||
#[cfg(unix)]
|
||||
mod unix;
|
||||
|
||||
#[cfg(unix)]
|
||||
pub use unix::*;
|
||||
|
||||
#[cfg(windows)]
|
||||
mod windows;
|
||||
|
||||
#[cfg(windows)]
|
||||
pub use windows::*;
|
@ -1,30 +0,0 @@
|
||||
use crate::{DistantManager, DistantManagerConfig};
|
||||
use distant_net::{
|
||||
Codec, FramedTransport, IntoSplit, MappedListener, PortRange, TcpListener, TcpServerRef,
|
||||
};
|
||||
use std::{io, net::IpAddr};
|
||||
|
||||
impl DistantManager {
|
||||
/// Start a new server by binding to the given IP address and one of the ports in the
|
||||
/// specified range, mapping all connections to use the given codec
|
||||
pub async fn start_tcp<P, C>(
|
||||
config: DistantManagerConfig,
|
||||
addr: IpAddr,
|
||||
port: P,
|
||||
codec: C,
|
||||
) -> io::Result<TcpServerRef>
|
||||
where
|
||||
P: Into<PortRange> + Send,
|
||||
C: Codec + Send + Sync + 'static,
|
||||
{
|
||||
let listener = TcpListener::bind(addr, port).await?;
|
||||
let port = listener.port();
|
||||
|
||||
let listener = MappedListener::new(listener, move |transport| {
|
||||
let transport = FramedTransport::new(transport, codec.clone());
|
||||
transport.into_split()
|
||||
});
|
||||
let inner = DistantManager::start(config, listener)?;
|
||||
Ok(TcpServerRef::new(addr, port, Box::new(inner)))
|
||||
}
|
||||
}
|
@ -1,50 +0,0 @@
|
||||
use crate::{DistantManager, DistantManagerConfig};
|
||||
use distant_net::{
|
||||
Codec, FramedTransport, IntoSplit, MappedListener, UnixSocketListener, UnixSocketServerRef,
|
||||
};
|
||||
use std::{io, path::Path};
|
||||
|
||||
impl DistantManager {
|
||||
/// Start a new server using the specified path as a unix socket using default unix socket file
|
||||
/// permissions
|
||||
pub async fn start_unix_socket<P, C>(
|
||||
config: DistantManagerConfig,
|
||||
path: P,
|
||||
codec: C,
|
||||
) -> io::Result<UnixSocketServerRef>
|
||||
where
|
||||
P: AsRef<Path> + Send,
|
||||
C: Codec + Send + Sync + 'static,
|
||||
{
|
||||
Self::start_unix_socket_with_permissions(
|
||||
config,
|
||||
path,
|
||||
codec,
|
||||
UnixSocketListener::default_unix_socket_file_permissions(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Start a new server using the specified path as a unix socket and `mode` as the unix socket
|
||||
/// file permissions
|
||||
pub async fn start_unix_socket_with_permissions<P, C>(
|
||||
config: DistantManagerConfig,
|
||||
path: P,
|
||||
codec: C,
|
||||
mode: u32,
|
||||
) -> io::Result<UnixSocketServerRef>
|
||||
where
|
||||
P: AsRef<Path> + Send,
|
||||
C: Codec + Send + Sync + 'static,
|
||||
{
|
||||
let listener = UnixSocketListener::bind_with_permissions(path, mode).await?;
|
||||
let path = listener.path().to_path_buf();
|
||||
|
||||
let listener = MappedListener::new(listener, move |transport| {
|
||||
let transport = FramedTransport::new(transport, codec.clone());
|
||||
transport.into_split()
|
||||
});
|
||||
let inner = DistantManager::start(config, listener)?;
|
||||
Ok(UnixSocketServerRef::new(path, Box::new(inner)))
|
||||
}
|
||||
}
|
@ -1,48 +0,0 @@
|
||||
use crate::{DistantManager, DistantManagerConfig};
|
||||
use distant_net::{
|
||||
Codec, FramedTransport, IntoSplit, MappedListener, WindowsPipeListener, WindowsPipeServerRef,
|
||||
};
|
||||
use std::{
|
||||
ffi::{OsStr, OsString},
|
||||
io,
|
||||
};
|
||||
|
||||
impl DistantManager {
|
||||
/// Start a new server at the specified address via `\\.\pipe\{name}` using the given codec
|
||||
pub async fn start_local_named_pipe<N, C>(
|
||||
config: DistantManagerConfig,
|
||||
name: N,
|
||||
codec: C,
|
||||
) -> io::Result<WindowsPipeServerRef>
|
||||
where
|
||||
Self: Sized,
|
||||
N: AsRef<OsStr> + Send,
|
||||
C: Codec + Send + Sync + 'static,
|
||||
{
|
||||
let mut addr = OsString::from(r"\\.\pipe\");
|
||||
addr.push(name.as_ref());
|
||||
Self::start_named_pipe(config, addr, codec).await
|
||||
}
|
||||
|
||||
/// Start a new server at the specified pipe address using the given codec
|
||||
pub async fn start_named_pipe<A, C>(
|
||||
config: DistantManagerConfig,
|
||||
addr: A,
|
||||
codec: C,
|
||||
) -> io::Result<WindowsPipeServerRef>
|
||||
where
|
||||
A: AsRef<OsStr> + Send,
|
||||
C: Codec + Send + Sync + 'static,
|
||||
{
|
||||
let a = addr.as_ref();
|
||||
let listener = WindowsPipeListener::bind(a)?;
|
||||
let addr = listener.addr().to_os_string();
|
||||
|
||||
let listener = MappedListener::new(listener, move |transport| {
|
||||
let transport = FramedTransport::new(transport, codec.clone());
|
||||
transport.into_split()
|
||||
});
|
||||
let inner = DistantManager::start(config, listener)?;
|
||||
Ok(WindowsPipeServerRef::new(addr, Box::new(inner)))
|
||||
}
|
||||
}
|
@ -1,68 +0,0 @@
|
||||
use crate::{
|
||||
data::Map, manager::data::Destination, DistantMsg, DistantRequestData, DistantResponseData,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use distant_net::{AuthClient, Request, Response, TypedAsyncRead, TypedAsyncWrite};
|
||||
use std::{future::Future, io};
|
||||
|
||||
pub type BoxedDistantWriter =
|
||||
Box<dyn TypedAsyncWrite<Request<DistantMsg<DistantRequestData>>> + Send>;
|
||||
pub type BoxedDistantReader =
|
||||
Box<dyn TypedAsyncRead<Response<DistantMsg<DistantResponseData>>> + Send>;
|
||||
pub type BoxedDistantWriterReader = (BoxedDistantWriter, BoxedDistantReader);
|
||||
pub type BoxedLaunchHandler = Box<dyn LaunchHandler>;
|
||||
pub type BoxedConnectHandler = Box<dyn ConnectHandler>;
|
||||
|
||||
/// Used to launch a server at the specified destination, returning some result as a vec of bytes
|
||||
#[async_trait]
|
||||
pub trait LaunchHandler: Send + Sync {
|
||||
async fn launch(
|
||||
&self,
|
||||
destination: &Destination,
|
||||
options: &Map,
|
||||
auth_client: &mut AuthClient,
|
||||
) -> io::Result<Destination>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<F, R> LaunchHandler for F
|
||||
where
|
||||
F: for<'a> Fn(&'a Destination, &'a Map, &'a mut AuthClient) -> R + Send + Sync + 'static,
|
||||
R: Future<Output = io::Result<Destination>> + Send + 'static,
|
||||
{
|
||||
async fn launch(
|
||||
&self,
|
||||
destination: &Destination,
|
||||
options: &Map,
|
||||
auth_client: &mut AuthClient,
|
||||
) -> io::Result<Destination> {
|
||||
self(destination, options, auth_client).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Used to connect to a destination, returning a connected reader and writer pair
|
||||
#[async_trait]
|
||||
pub trait ConnectHandler: Send + Sync {
|
||||
async fn connect(
|
||||
&self,
|
||||
destination: &Destination,
|
||||
options: &Map,
|
||||
auth_client: &mut AuthClient,
|
||||
) -> io::Result<BoxedDistantWriterReader>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<F, R> ConnectHandler for F
|
||||
where
|
||||
F: for<'a> Fn(&'a Destination, &'a Map, &'a mut AuthClient) -> R + Send + Sync + 'static,
|
||||
R: Future<Output = io::Result<BoxedDistantWriterReader>> + Send + 'static,
|
||||
{
|
||||
async fn connect(
|
||||
&self,
|
||||
destination: &Destination,
|
||||
options: &Map,
|
||||
auth_client: &mut AuthClient,
|
||||
) -> io::Result<BoxedDistantWriterReader> {
|
||||
self(destination, options, auth_client).await
|
||||
}
|
||||
}
|
@ -1,73 +0,0 @@
|
||||
use super::{BoxedConnectHandler, BoxedLaunchHandler, ConnectHandler, LaunchHandler};
|
||||
use distant_net::{ServerRef, ServerState};
|
||||
use std::{collections::HashMap, io, sync::Weak};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// Reference to a distant manager's server instance
|
||||
pub struct DistantManagerRef {
|
||||
/// Mapping of "scheme" -> handler
|
||||
pub(crate) launch_handlers: Weak<RwLock<HashMap<String, BoxedLaunchHandler>>>,
|
||||
|
||||
/// Mapping of "scheme" -> handler
|
||||
pub(crate) connect_handlers: Weak<RwLock<HashMap<String, BoxedConnectHandler>>>,
|
||||
|
||||
pub(crate) inner: Box<dyn ServerRef>,
|
||||
}
|
||||
|
||||
impl DistantManagerRef {
|
||||
/// Registers a new [`LaunchHandler`] for the specified scheme (e.g. "distant" or "ssh")
|
||||
pub async fn register_launch_handler(
|
||||
&self,
|
||||
scheme: impl Into<String>,
|
||||
handler: impl LaunchHandler + 'static,
|
||||
) -> io::Result<()> {
|
||||
let handlers = Weak::upgrade(&self.launch_handlers).ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Handler reference is no longer available",
|
||||
)
|
||||
})?;
|
||||
|
||||
handlers
|
||||
.write()
|
||||
.await
|
||||
.insert(scheme.into(), Box::new(handler));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Registers a new [`ConnectHandler`] for the specified scheme (e.g. "distant" or "ssh")
|
||||
pub async fn register_connect_handler(
|
||||
&self,
|
||||
scheme: impl Into<String>,
|
||||
handler: impl ConnectHandler + 'static,
|
||||
) -> io::Result<()> {
|
||||
let handlers = Weak::upgrade(&self.connect_handlers).ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Handler reference is no longer available",
|
||||
)
|
||||
})?;
|
||||
|
||||
handlers
|
||||
.write()
|
||||
.await
|
||||
.insert(scheme.into(), Box::new(handler));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerRef for DistantManagerRef {
|
||||
fn state(&self) -> &ServerState {
|
||||
self.inner.state()
|
||||
}
|
||||
|
||||
fn is_finished(&self) -> bool {
|
||||
self.inner.is_finished()
|
||||
}
|
||||
|
||||
fn abort(&self) {
|
||||
self.inner.abort();
|
||||
}
|
||||
}
|
@ -1,96 +0,0 @@
|
||||
use distant_core::{
|
||||
net::{FramedTransport, InmemoryTransport, IntoSplit, OneshotListener, PlainCodec},
|
||||
BoxedDistantReader, BoxedDistantWriter, Destination, DistantApiServer, DistantChannelExt,
|
||||
DistantManager, DistantManagerClient, DistantManagerClientConfig, DistantManagerConfig, Map,
|
||||
};
|
||||
use std::io;
|
||||
|
||||
/// Creates a client transport and server listener for our tests
|
||||
/// that are connected together
|
||||
async fn setup() -> (
|
||||
FramedTransport<InmemoryTransport, PlainCodec>,
|
||||
OneshotListener<FramedTransport<InmemoryTransport, PlainCodec>>,
|
||||
) {
|
||||
let (t1, t2) = InmemoryTransport::pair(100);
|
||||
|
||||
let listener = OneshotListener::from_value(FramedTransport::new(t2, PlainCodec));
|
||||
let transport = FramedTransport::new(t1, PlainCodec);
|
||||
(transport, listener)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn should_be_able_to_establish_a_single_connection_and_communicate() {
|
||||
let (transport, listener) = setup().await;
|
||||
|
||||
let config = DistantManagerConfig::default();
|
||||
let manager_ref = DistantManager::start(config, listener).expect("Failed to start manager");
|
||||
|
||||
// NOTE: To pass in a raw function, we HAVE to specify the types of the parameters manually,
|
||||
// otherwise we get a compilation error about lifetime mismatches
|
||||
manager_ref
|
||||
.register_connect_handler("scheme", |_: &_, _: &_, _: &mut _| async {
|
||||
use distant_core::net::ServerExt;
|
||||
let (t1, t2) = FramedTransport::pair(100);
|
||||
|
||||
// Spawn a server on one end
|
||||
let _ = DistantApiServer::local(Default::default())
|
||||
.unwrap()
|
||||
.start(OneshotListener::from_value(t2.into_split()))?;
|
||||
|
||||
// Create a reader/writer pair on the other end
|
||||
let (writer, reader) = t1.into_split();
|
||||
let writer: BoxedDistantWriter = Box::new(writer);
|
||||
let reader: BoxedDistantReader = Box::new(reader);
|
||||
Ok((writer, reader))
|
||||
})
|
||||
.await
|
||||
.expect("Failed to register handler");
|
||||
|
||||
let config = DistantManagerClientConfig::with_empty_prompts();
|
||||
let mut client =
|
||||
DistantManagerClient::new(config, transport).expect("Failed to connect to manager");
|
||||
|
||||
// Test establishing a connection to some remote server
|
||||
let id = client
|
||||
.connect(
|
||||
"scheme://host".parse::<Destination>().unwrap(),
|
||||
"key=value".parse::<Map>().unwrap(),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to connect to a remote server");
|
||||
|
||||
// Test retrieving list of connections
|
||||
let list = client
|
||||
.list()
|
||||
.await
|
||||
.expect("Failed to get list of connections");
|
||||
assert_eq!(list.len(), 1);
|
||||
assert_eq!(list.get(&id).unwrap().to_string(), "scheme://host");
|
||||
|
||||
// Test retrieving information
|
||||
let info = client
|
||||
.info(id)
|
||||
.await
|
||||
.expect("Failed to get info about connection");
|
||||
assert_eq!(info.id, id);
|
||||
assert_eq!(info.destination.to_string(), "scheme://host");
|
||||
assert_eq!(info.options, "key=value".parse::<Map>().unwrap());
|
||||
|
||||
// Create a new channel and request some data
|
||||
let mut channel = client
|
||||
.open_channel(id)
|
||||
.await
|
||||
.expect("Failed to open channel");
|
||||
let _ = channel
|
||||
.system_info()
|
||||
.await
|
||||
.expect("Failed to get system information");
|
||||
|
||||
// Test killing a connection
|
||||
client.kill(id).await.expect("Failed to kill connection");
|
||||
|
||||
// Test getting an error to ensure that serialization of that data works,
|
||||
// which we do by trying to access a connection that no longer exists
|
||||
let err = client.info(id).await.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::NotConnected);
|
||||
}
|
@ -1,23 +0,0 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Initializes logging (should only call once)
|
||||
pub fn init_logging(path: impl Into<PathBuf>) -> flexi_logger::LoggerHandle {
|
||||
use flexi_logger::{FileSpec, LevelFilter, LogSpecification, Logger};
|
||||
let modules = &["distant", "distant_core", "distant_ssh2"];
|
||||
|
||||
// Disable logging for everything but our binary, which is based on verbosity
|
||||
let mut builder = LogSpecification::builder();
|
||||
builder.default(LevelFilter::Off);
|
||||
|
||||
// For each module, configure logging
|
||||
for module in modules {
|
||||
builder.module(module, LevelFilter::Trace);
|
||||
}
|
||||
|
||||
// Create our logger, but don't initialize yet
|
||||
let logger = Logger::with(builder.build())
|
||||
.format_for_files(flexi_logger::opt_format)
|
||||
.log_to_file(FileSpec::try_from(path).expect("Failed to create log file spec"));
|
||||
|
||||
logger.start().expect("Failed to initialize logger")
|
||||
}
|
@ -1,122 +0,0 @@
|
||||
use derive_more::Display;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
mod client;
|
||||
pub use client::*;
|
||||
|
||||
mod handshake;
|
||||
pub use handshake::*;
|
||||
|
||||
mod server;
|
||||
pub use server::*;
|
||||
|
||||
/// Represents authentication messages that can be sent over the wire
|
||||
///
|
||||
/// NOTE: Must use serde's content attribute with the tag attribute. Just the tag attribute will
|
||||
/// cause deserialization to fail
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case", tag = "type", content = "data")]
|
||||
pub enum Auth {
|
||||
/// Represents a request to perform an authentication handshake,
|
||||
/// providing the public key and salt from one side in order to
|
||||
/// derive the shared key
|
||||
#[serde(rename = "auth_handshake")]
|
||||
Handshake {
|
||||
/// Bytes of the public key
|
||||
#[serde(with = "serde_bytes")]
|
||||
public_key: PublicKeyBytes,
|
||||
|
||||
/// Randomly generated salt
|
||||
#[serde(with = "serde_bytes")]
|
||||
salt: Salt,
|
||||
},
|
||||
|
||||
/// Represents the bytes of an encrypted message
|
||||
///
|
||||
/// Underneath, will be one of either [`AuthRequest`] or [`AuthResponse`]
|
||||
#[serde(rename = "auth_msg")]
|
||||
Msg {
|
||||
#[serde(with = "serde_bytes")]
|
||||
encrypted_payload: Vec<u8>,
|
||||
},
|
||||
}
|
||||
|
||||
/// Represents authentication messages that act as initiators such as providing
|
||||
/// a challenge, verifying information, presenting information, or highlighting an error
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case", tag = "type")]
|
||||
pub enum AuthRequest {
|
||||
/// Represents a challenge comprising a series of questions to be presented
|
||||
Challenge {
|
||||
questions: Vec<AuthQuestion>,
|
||||
options: HashMap<String, String>,
|
||||
},
|
||||
|
||||
/// Represents an ask to verify some information
|
||||
Verify { kind: AuthVerifyKind, text: String },
|
||||
|
||||
/// Represents some information to be presented
|
||||
Info { text: String },
|
||||
|
||||
/// Represents some error that occurred
|
||||
Error { kind: AuthErrorKind, text: String },
|
||||
}
|
||||
|
||||
/// Represents authentication messages that are responses to auth requests such
|
||||
/// as answers to challenges or verifying information
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case", tag = "type")]
|
||||
pub enum AuthResponse {
|
||||
/// Represents the answers to a previously-asked challenge
|
||||
Challenge { answers: Vec<String> },
|
||||
|
||||
/// Represents the answer to a previously-asked verify
|
||||
Verify { valid: bool },
|
||||
}
|
||||
|
||||
/// Represents the type of verification being requested
|
||||
#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[non_exhaustive]
|
||||
pub enum AuthVerifyKind {
|
||||
/// An ask to verify the host such as with SSH
|
||||
#[display(fmt = "host")]
|
||||
Host,
|
||||
}
|
||||
|
||||
/// Represents a single question in a challenge
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct AuthQuestion {
|
||||
/// The text of the question
|
||||
pub text: String,
|
||||
|
||||
/// Any options information specific to a particular auth domain
|
||||
/// such as including a username and instructions for SSH authentication
|
||||
pub options: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl AuthQuestion {
|
||||
/// Creates a new question without any options data
|
||||
pub fn new(text: impl Into<String>) -> Self {
|
||||
Self {
|
||||
text: text.into(),
|
||||
options: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents the type of error encountered during authentication
|
||||
#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum AuthErrorKind {
|
||||
/// When the answer(s) to a challenge do not pass authentication
|
||||
FailedChallenge,
|
||||
|
||||
/// When verification during authentication fails
|
||||
/// (e.g. a host is not allowed or blocked)
|
||||
FailedVerification,
|
||||
|
||||
/// When the error is unknown
|
||||
Unknown,
|
||||
}
|
@ -1,817 +0,0 @@
|
||||
use crate::{
|
||||
utils, Auth, AuthErrorKind, AuthQuestion, AuthRequest, AuthResponse, AuthVerifyKind, Client,
|
||||
Codec, Handshake, XChaCha20Poly1305Codec,
|
||||
};
|
||||
use bytes::BytesMut;
|
||||
use log::*;
|
||||
use std::{collections::HashMap, io};
|
||||
|
||||
pub struct AuthClient {
|
||||
inner: Client<Auth, Auth>,
|
||||
codec: Option<XChaCha20Poly1305Codec>,
|
||||
jit_handshake: bool,
|
||||
}
|
||||
|
||||
impl From<Client<Auth, Auth>> for AuthClient {
|
||||
fn from(client: Client<Auth, Auth>) -> Self {
|
||||
Self {
|
||||
inner: client,
|
||||
codec: None,
|
||||
jit_handshake: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AuthClient {
|
||||
/// Sends a request to the server to establish an encrypted connection
|
||||
pub async fn handshake(&mut self) -> io::Result<()> {
|
||||
let handshake = Handshake::default();
|
||||
|
||||
let response = self
|
||||
.inner
|
||||
.send(Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
})
|
||||
.await?;
|
||||
|
||||
match response.payload {
|
||||
Auth::Handshake { public_key, salt } => {
|
||||
let key = handshake.handshake(public_key, salt)?;
|
||||
self.codec.replace(XChaCha20Poly1305Codec::new(&key));
|
||||
Ok(())
|
||||
}
|
||||
Auth::Msg { .. } => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Got unexpected encrypted message during handshake",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a handshake only if jit is enabled and no handshake has succeeded yet
|
||||
async fn jit_handshake(&mut self) -> io::Result<()> {
|
||||
if self.will_jit_handshake() && !self.is_ready() {
|
||||
self.handshake().await
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if client has successfully performed a handshake
|
||||
/// and is ready to communicate with the server
|
||||
pub fn is_ready(&self) -> bool {
|
||||
self.codec.is_some()
|
||||
}
|
||||
|
||||
/// Returns true if this client will perform a handshake just-in-time (JIT) prior to making a
|
||||
/// request in the scenario where the client has not already performed a handshake
|
||||
#[inline]
|
||||
pub fn will_jit_handshake(&self) -> bool {
|
||||
self.jit_handshake
|
||||
}
|
||||
|
||||
/// Sets the jit flag on this client with `true` indicating that this client will perform a
|
||||
/// handshake just-in-time (JIT) prior to making a request in the scenario where the client has
|
||||
/// not already performed a handshake
|
||||
#[inline]
|
||||
pub fn set_jit_handshake(&mut self, flag: bool) {
|
||||
self.jit_handshake = flag;
|
||||
}
|
||||
|
||||
/// Provides a challenge to the server and returns the answers to the questions
|
||||
/// asked by the client
|
||||
pub async fn challenge(
|
||||
&mut self,
|
||||
questions: Vec<AuthQuestion>,
|
||||
options: HashMap<String, String>,
|
||||
) -> io::Result<Vec<String>> {
|
||||
trace!(
|
||||
"AuthClient::challenge(questions = {:?}, options = {:?})",
|
||||
questions,
|
||||
options
|
||||
);
|
||||
|
||||
// Perform JIT handshake if enabled
|
||||
self.jit_handshake().await?;
|
||||
|
||||
let payload = AuthRequest::Challenge { questions, options };
|
||||
let encrypted_payload = self.serialize_and_encrypt(&payload)?;
|
||||
let response = self.inner.send(Auth::Msg { encrypted_payload }).await?;
|
||||
|
||||
match response.payload {
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match self.decrypt_and_deserialize(&encrypted_payload)? {
|
||||
AuthResponse::Challenge { answers } => Ok(answers),
|
||||
AuthResponse::Verify { .. } => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Got unexpected verify response during challenge",
|
||||
)),
|
||||
}
|
||||
}
|
||||
Auth::Handshake { .. } => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Got unexpected handshake during challenge",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Provides a verification request to the server and returns whether or not
|
||||
/// the server approved
|
||||
pub async fn verify(&mut self, kind: AuthVerifyKind, text: String) -> io::Result<bool> {
|
||||
trace!("AuthClient::verify(kind = {:?}, text = {:?})", kind, text);
|
||||
|
||||
// Perform JIT handshake if enabled
|
||||
self.jit_handshake().await?;
|
||||
|
||||
let payload = AuthRequest::Verify { kind, text };
|
||||
let encrypted_payload = self.serialize_and_encrypt(&payload)?;
|
||||
let response = self.inner.send(Auth::Msg { encrypted_payload }).await?;
|
||||
|
||||
match response.payload {
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match self.decrypt_and_deserialize(&encrypted_payload)? {
|
||||
AuthResponse::Verify { valid } => Ok(valid),
|
||||
AuthResponse::Challenge { .. } => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Got unexpected challenge response during verify",
|
||||
)),
|
||||
}
|
||||
}
|
||||
Auth::Handshake { .. } => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Got unexpected handshake during verify",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Provides information to the server to use as it pleases with no response expected
|
||||
pub async fn info(&mut self, text: String) -> io::Result<()> {
|
||||
trace!("AuthClient::info(text = {:?})", text);
|
||||
|
||||
// Perform JIT handshake if enabled
|
||||
self.jit_handshake().await?;
|
||||
|
||||
let payload = AuthRequest::Info { text };
|
||||
let encrypted_payload = self.serialize_and_encrypt(&payload)?;
|
||||
self.inner.fire(Auth::Msg { encrypted_payload }).await
|
||||
}
|
||||
|
||||
/// Provides an error to the server to use as it pleases with no response expected
|
||||
pub async fn error(&mut self, kind: AuthErrorKind, text: String) -> io::Result<()> {
|
||||
trace!("AuthClient::error(kind = {:?}, text = {:?})", kind, text);
|
||||
|
||||
// Perform JIT handshake if enabled
|
||||
self.jit_handshake().await?;
|
||||
|
||||
let payload = AuthRequest::Error { kind, text };
|
||||
let encrypted_payload = self.serialize_and_encrypt(&payload)?;
|
||||
self.inner.fire(Auth::Msg { encrypted_payload }).await
|
||||
}
|
||||
|
||||
fn serialize_and_encrypt(&mut self, payload: &AuthRequest) -> io::Result<Vec<u8>> {
|
||||
let codec = self.codec.as_mut().ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Handshake must be performed first (client encrypt message)",
|
||||
)
|
||||
})?;
|
||||
|
||||
let mut encryped_payload = BytesMut::new();
|
||||
let payload = utils::serialize_to_vec(payload)?;
|
||||
codec.encode(&payload, &mut encryped_payload)?;
|
||||
Ok(encryped_payload.freeze().to_vec())
|
||||
}
|
||||
|
||||
fn decrypt_and_deserialize(&mut self, payload: &[u8]) -> io::Result<AuthResponse> {
|
||||
let codec = self.codec.as_mut().ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Handshake must be performed first (client decrypt message)",
|
||||
)
|
||||
})?;
|
||||
|
||||
let mut payload = BytesMut::from(payload);
|
||||
match codec.decode(&mut payload)? {
|
||||
Some(payload) => utils::deserialize_from_slice::<AuthResponse>(&payload),
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"Incomplete message received",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{Client, FramedTransport, Request, Response, TypedAsyncRead, TypedAsyncWrite};
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
|
||||
const TIMEOUT_MILLIS: u64 = 100;
|
||||
|
||||
#[tokio::test]
|
||||
async fn handshake_should_fail_if_get_unexpected_response_from_server() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move { client.handshake().await });
|
||||
|
||||
// Get the request, but send a bad response
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
match request.payload {
|
||||
Auth::Handshake { .. } => server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Msg {
|
||||
encrypted_payload: Vec::new(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap(),
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
}
|
||||
|
||||
let result = task.await.unwrap();
|
||||
assert!(result.is_err(), "Handshake succeeded unexpectedly")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn challenge_should_fail_if_handshake_not_finished() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move { client.challenge(Vec::new(), HashMap::new()).await });
|
||||
|
||||
// Wait for a request, failing if we get one as the failure
|
||||
// should have prevented sending anything, but we should
|
||||
tokio::select! {
|
||||
x = TypedAsyncRead::<Request<Auth>>::read(&mut server) => {
|
||||
match x {
|
||||
Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x),
|
||||
Ok(None) => {},
|
||||
Err(x) => panic!("Unexpectedly failed on server side: {}", x),
|
||||
}
|
||||
},
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => {
|
||||
panic!("Should have gotten server closure as part of client exit");
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that we got an error with the method
|
||||
let result = task.await.unwrap();
|
||||
assert!(result.is_err(), "Challenge succeeded unexpectedly")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn challenge_should_fail_if_receive_wrong_response() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move {
|
||||
client.handshake().await.unwrap();
|
||||
client
|
||||
.challenge(
|
||||
vec![
|
||||
AuthQuestion::new("question1".to_string()),
|
||||
AuthQuestion {
|
||||
text: "question2".to_string(),
|
||||
options: vec![("key2".to_string(), "value2".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
},
|
||||
],
|
||||
vec![("key".to_string(), "value".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
// Wait for a handshake request and set up our encryption codec
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
let mut codec = match request.payload {
|
||||
Auth::Handshake { public_key, salt } => {
|
||||
let handshake = Handshake::default();
|
||||
let key = handshake.handshake(public_key, salt).unwrap();
|
||||
server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
XChaCha20Poly1305Codec::new(&key)
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Wait for a challenge request and send back wrong response
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
match request.payload {
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
|
||||
AuthRequest::Challenge { .. } => {
|
||||
server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Msg {
|
||||
encrypted_payload: serialize_and_encrypt(
|
||||
&mut codec,
|
||||
&AuthResponse::Verify { valid: true },
|
||||
)
|
||||
.unwrap(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
_ => panic!("Server received wrong request type"),
|
||||
}
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Verify that we got an error with the method
|
||||
let result = task.await.unwrap();
|
||||
assert!(result.is_err(), "Challenge succeeded unexpectedly")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn challenge_should_return_answers_received_from_server() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move {
|
||||
client.handshake().await.unwrap();
|
||||
client
|
||||
.challenge(
|
||||
vec![
|
||||
AuthQuestion::new("question1".to_string()),
|
||||
AuthQuestion {
|
||||
text: "question2".to_string(),
|
||||
options: vec![("key2".to_string(), "value2".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
},
|
||||
],
|
||||
vec![("key".to_string(), "value".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
// Wait for a handshake request and set up our encryption codec
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
let mut codec = match request.payload {
|
||||
Auth::Handshake { public_key, salt } => {
|
||||
let handshake = Handshake::default();
|
||||
let key = handshake.handshake(public_key, salt).unwrap();
|
||||
server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
XChaCha20Poly1305Codec::new(&key)
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Wait for a challenge request and send back wrong response
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
match request.payload {
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
|
||||
AuthRequest::Challenge { questions, options } => {
|
||||
assert_eq!(
|
||||
questions,
|
||||
vec![
|
||||
AuthQuestion::new("question1".to_string()),
|
||||
AuthQuestion {
|
||||
text: "question2".to_string(),
|
||||
options: vec![("key2".to_string(), "value2".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
},
|
||||
],
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
options,
|
||||
vec![("key".to_string(), "value".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
);
|
||||
|
||||
server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Msg {
|
||||
encrypted_payload: serialize_and_encrypt(
|
||||
&mut codec,
|
||||
&AuthResponse::Challenge {
|
||||
answers: vec![
|
||||
"answer1".to_string(),
|
||||
"answer2".to_string(),
|
||||
],
|
||||
},
|
||||
)
|
||||
.unwrap(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
_ => panic!("Server received wrong request type"),
|
||||
}
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Verify that we got the right results
|
||||
let answers = task.await.unwrap().unwrap();
|
||||
assert_eq!(answers, vec!["answer1".to_string(), "answer2".to_string()]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn verify_should_fail_if_handshake_not_finished() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move {
|
||||
client
|
||||
.verify(AuthVerifyKind::Host, "some text".to_string())
|
||||
.await
|
||||
});
|
||||
|
||||
// Wait for a request, failing if we get one as the failure
|
||||
// should have prevented sending anything, but we should
|
||||
tokio::select! {
|
||||
x = TypedAsyncRead::<Request<Auth>>::read(&mut server) => {
|
||||
match x {
|
||||
Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x),
|
||||
Ok(None) => {},
|
||||
Err(x) => panic!("Unexpectedly failed on server side: {}", x),
|
||||
}
|
||||
},
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => {
|
||||
panic!("Should have gotten server closure as part of client exit");
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that we got an error with the method
|
||||
let result = task.await.unwrap();
|
||||
assert!(result.is_err(), "Verify succeeded unexpectedly")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn verify_should_fail_if_receive_wrong_response() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move {
|
||||
client.handshake().await.unwrap();
|
||||
client
|
||||
.verify(AuthVerifyKind::Host, "some text".to_string())
|
||||
.await
|
||||
});
|
||||
|
||||
// Wait for a handshake request and set up our encryption codec
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
let mut codec = match request.payload {
|
||||
Auth::Handshake { public_key, salt } => {
|
||||
let handshake = Handshake::default();
|
||||
let key = handshake.handshake(public_key, salt).unwrap();
|
||||
server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
XChaCha20Poly1305Codec::new(&key)
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Wait for a verify request and send back wrong response
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
match request.payload {
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
|
||||
AuthRequest::Verify { .. } => {
|
||||
server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Msg {
|
||||
encrypted_payload: serialize_and_encrypt(
|
||||
&mut codec,
|
||||
&AuthResponse::Challenge {
|
||||
answers: Vec::new(),
|
||||
},
|
||||
)
|
||||
.unwrap(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
_ => panic!("Server received wrong request type"),
|
||||
}
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Verify that we got an error with the method
|
||||
let result = task.await.unwrap();
|
||||
assert!(result.is_err(), "Verify succeeded unexpectedly")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn verify_should_return_valid_bool_received_from_server() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move {
|
||||
client.handshake().await.unwrap();
|
||||
client
|
||||
.verify(AuthVerifyKind::Host, "some text".to_string())
|
||||
.await
|
||||
});
|
||||
|
||||
// Wait for a handshake request and set up our encryption codec
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
let mut codec = match request.payload {
|
||||
Auth::Handshake { public_key, salt } => {
|
||||
let handshake = Handshake::default();
|
||||
let key = handshake.handshake(public_key, salt).unwrap();
|
||||
server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
XChaCha20Poly1305Codec::new(&key)
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Wait for a challenge request and send back wrong response
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
match request.payload {
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
|
||||
AuthRequest::Verify { kind, text } => {
|
||||
assert_eq!(kind, AuthVerifyKind::Host);
|
||||
assert_eq!(text, "some text");
|
||||
|
||||
server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Msg {
|
||||
encrypted_payload: serialize_and_encrypt(
|
||||
&mut codec,
|
||||
&AuthResponse::Verify { valid: true },
|
||||
)
|
||||
.unwrap(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
_ => panic!("Server received wrong request type"),
|
||||
}
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Verify that we got the right results
|
||||
let valid = task.await.unwrap().unwrap();
|
||||
assert!(valid, "Got verify response, but valid was set incorrectly");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn info_should_fail_if_handshake_not_finished() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move { client.info("some text".to_string()).await });
|
||||
|
||||
// Wait for a request, failing if we get one as the failure
|
||||
// should have prevented sending anything, but we should
|
||||
tokio::select! {
|
||||
x = TypedAsyncRead::<Request<Auth>>::read(&mut server) => {
|
||||
match x {
|
||||
Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x),
|
||||
Ok(None) => {},
|
||||
Err(x) => panic!("Unexpectedly failed on server side: {}", x),
|
||||
}
|
||||
},
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => {
|
||||
panic!("Should have gotten server closure as part of client exit");
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that we got an error with the method
|
||||
let result = task.await.unwrap();
|
||||
assert!(result.is_err(), "Info succeeded unexpectedly")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn info_should_send_the_server_a_request_but_not_wait_for_a_response() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move {
|
||||
client.handshake().await.unwrap();
|
||||
client.info("some text".to_string()).await
|
||||
});
|
||||
|
||||
// Wait for a handshake request and set up our encryption codec
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
let mut codec = match request.payload {
|
||||
Auth::Handshake { public_key, salt } => {
|
||||
let handshake = Handshake::default();
|
||||
let key = handshake.handshake(public_key, salt).unwrap();
|
||||
server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
XChaCha20Poly1305Codec::new(&key)
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Wait for a request
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
match request.payload {
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
|
||||
AuthRequest::Info { text } => {
|
||||
assert_eq!(text, "some text");
|
||||
}
|
||||
_ => panic!("Server received wrong request type"),
|
||||
}
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Verify that we got the right results
|
||||
task.await.unwrap().unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn error_should_fail_if_handshake_not_finished() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move {
|
||||
client
|
||||
.error(AuthErrorKind::FailedChallenge, "some text".to_string())
|
||||
.await
|
||||
});
|
||||
|
||||
// Wait for a request, failing if we get one as the failure
|
||||
// should have prevented sending anything, but we should
|
||||
tokio::select! {
|
||||
x = TypedAsyncRead::<Request<Auth>>::read(&mut server) => {
|
||||
match x {
|
||||
Ok(Some(x)) => panic!("Unexpectedly resolved: {:?}", x),
|
||||
Ok(None) => {},
|
||||
Err(x) => panic!("Unexpectedly failed on server side: {}", x),
|
||||
}
|
||||
},
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => {
|
||||
panic!("Should have gotten server closure as part of client exit");
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that we got an error with the method
|
||||
let result = task.await.unwrap();
|
||||
assert!(result.is_err(), "Error succeeded unexpectedly")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn error_should_send_the_server_a_request_but_not_wait_for_a_response() {
|
||||
let (t, mut server) = FramedTransport::make_test_pair();
|
||||
let mut client = AuthClient::from(Client::from_framed_transport(t).unwrap());
|
||||
|
||||
// We start a separate task for the client to avoid blocking since
|
||||
// we also need to receive the client's request and respond
|
||||
let task = tokio::spawn(async move {
|
||||
client.handshake().await.unwrap();
|
||||
client
|
||||
.error(AuthErrorKind::FailedChallenge, "some text".to_string())
|
||||
.await
|
||||
});
|
||||
|
||||
// Wait for a handshake request and set up our encryption codec
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
let mut codec = match request.payload {
|
||||
Auth::Handshake { public_key, salt } => {
|
||||
let handshake = Handshake::default();
|
||||
let key = handshake.handshake(public_key, salt).unwrap();
|
||||
server
|
||||
.write(Response::new(
|
||||
request.id,
|
||||
Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
XChaCha20Poly1305Codec::new(&key)
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Wait for a request
|
||||
let request: Request<Auth> = server.read().await.unwrap().unwrap();
|
||||
match request.payload {
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
|
||||
AuthRequest::Error { kind, text } => {
|
||||
assert_eq!(kind, AuthErrorKind::FailedChallenge);
|
||||
assert_eq!(text, "some text");
|
||||
}
|
||||
_ => panic!("Server received wrong request type"),
|
||||
}
|
||||
}
|
||||
_ => panic!("Server received unexpected payload"),
|
||||
};
|
||||
|
||||
// Verify that we got the right results
|
||||
task.await.unwrap().unwrap();
|
||||
}
|
||||
|
||||
async fn wait_ms(ms: u64) {
|
||||
use std::time::Duration;
|
||||
tokio::time::sleep(Duration::from_millis(ms)).await;
|
||||
}
|
||||
|
||||
fn serialize_and_encrypt<T: Serialize>(
|
||||
codec: &mut XChaCha20Poly1305Codec,
|
||||
payload: &T,
|
||||
) -> io::Result<Vec<u8>> {
|
||||
let mut encryped_payload = BytesMut::new();
|
||||
let payload = utils::serialize_to_vec(payload)?;
|
||||
codec.encode(&payload, &mut encryped_payload)?;
|
||||
Ok(encryped_payload.freeze().to_vec())
|
||||
}
|
||||
|
||||
fn decrypt_and_deserialize<T: DeserializeOwned>(
|
||||
codec: &mut XChaCha20Poly1305Codec,
|
||||
payload: &[u8],
|
||||
) -> io::Result<T> {
|
||||
let mut payload = BytesMut::from(payload);
|
||||
match codec.decode(&mut payload)? {
|
||||
Some(payload) => utils::deserialize_from_slice::<T>(&payload),
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"Incomplete message received",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
@ -1,653 +0,0 @@
|
||||
use crate::{
|
||||
utils, Auth, AuthErrorKind, AuthQuestion, AuthRequest, AuthResponse, AuthVerifyKind, Codec,
|
||||
Handshake, Server, ServerCtx, XChaCha20Poly1305Codec,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use bytes::BytesMut;
|
||||
use log::*;
|
||||
use std::{collections::HashMap, io};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// Type signature for a dynamic on_challenge function
|
||||
pub type AuthChallengeFn =
|
||||
dyn Fn(Vec<AuthQuestion>, HashMap<String, String>) -> Vec<String> + Send + Sync;
|
||||
|
||||
/// Type signature for a dynamic on_verify function
|
||||
pub type AuthVerifyFn = dyn Fn(AuthVerifyKind, String) -> bool + Send + Sync;
|
||||
|
||||
/// Type signature for a dynamic on_info function
|
||||
pub type AuthInfoFn = dyn Fn(String) + Send + Sync;
|
||||
|
||||
/// Type signature for a dynamic on_error function
|
||||
pub type AuthErrorFn = dyn Fn(AuthErrorKind, String) + Send + Sync;
|
||||
|
||||
/// Represents an [`AuthServer`] where all handlers are stored on the heap
|
||||
pub type HeapAuthServer =
|
||||
AuthServer<Box<AuthChallengeFn>, Box<AuthVerifyFn>, Box<AuthInfoFn>, Box<AuthErrorFn>>;
|
||||
|
||||
/// Server that handles authentication
|
||||
pub struct AuthServer<ChallengeFn, VerifyFn, InfoFn, ErrorFn>
|
||||
where
|
||||
ChallengeFn: Fn(Vec<AuthQuestion>, HashMap<String, String>) -> Vec<String> + Send + Sync,
|
||||
VerifyFn: Fn(AuthVerifyKind, String) -> bool + Send + Sync,
|
||||
InfoFn: Fn(String) + Send + Sync,
|
||||
ErrorFn: Fn(AuthErrorKind, String) + Send + Sync,
|
||||
{
|
||||
pub on_challenge: ChallengeFn,
|
||||
pub on_verify: VerifyFn,
|
||||
pub on_info: InfoFn,
|
||||
pub on_error: ErrorFn,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<ChallengeFn, VerifyFn, InfoFn, ErrorFn> Server
|
||||
for AuthServer<ChallengeFn, VerifyFn, InfoFn, ErrorFn>
|
||||
where
|
||||
ChallengeFn: Fn(Vec<AuthQuestion>, HashMap<String, String>) -> Vec<String> + Send + Sync,
|
||||
VerifyFn: Fn(AuthVerifyKind, String) -> bool + Send + Sync,
|
||||
InfoFn: Fn(String) + Send + Sync,
|
||||
ErrorFn: Fn(AuthErrorKind, String) + Send + Sync,
|
||||
{
|
||||
type Request = Auth;
|
||||
type Response = Auth;
|
||||
type LocalData = RwLock<Option<XChaCha20Poly1305Codec>>;
|
||||
|
||||
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) {
|
||||
let reply = ctx.reply.clone();
|
||||
|
||||
match ctx.request.payload {
|
||||
Auth::Handshake { public_key, salt } => {
|
||||
trace!(
|
||||
"Received handshake request from client, request id = {}",
|
||||
ctx.request.id
|
||||
);
|
||||
let handshake = Handshake::default();
|
||||
match handshake.handshake(public_key, salt) {
|
||||
Ok(key) => {
|
||||
ctx.local_data
|
||||
.write()
|
||||
.await
|
||||
.replace(XChaCha20Poly1305Codec::new(&key));
|
||||
|
||||
trace!(
|
||||
"Sending reciprocal handshake to client, response origin id = {}",
|
||||
ctx.request.id
|
||||
);
|
||||
if let Err(x) = reply
|
||||
.send(Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
})
|
||||
.await
|
||||
{
|
||||
error!("[Conn {}] {}", ctx.connection_id, x);
|
||||
}
|
||||
}
|
||||
Err(x) => {
|
||||
error!("[Conn {}] {}", ctx.connection_id, x);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
Auth::Msg {
|
||||
ref encrypted_payload,
|
||||
} => {
|
||||
trace!(
|
||||
"Received auth msg, encrypted payload size = {}",
|
||||
encrypted_payload.len()
|
||||
);
|
||||
|
||||
// Attempt to decrypt the message so we can understand what to do
|
||||
let request = match ctx.local_data.write().await.as_mut() {
|
||||
Some(codec) => {
|
||||
let mut payload = BytesMut::from(encrypted_payload.as_slice());
|
||||
match codec.decode(&mut payload) {
|
||||
Ok(Some(payload)) => {
|
||||
utils::deserialize_from_slice::<AuthRequest>(&payload)
|
||||
}
|
||||
Ok(None) => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"Incomplete message received",
|
||||
)),
|
||||
Err(x) => Err(x),
|
||||
}
|
||||
}
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Handshake must be performed first (server decrypt message)",
|
||||
)),
|
||||
};
|
||||
|
||||
let response = match request {
|
||||
Ok(request) => match request {
|
||||
AuthRequest::Challenge { questions, options } => {
|
||||
trace!("Received challenge request");
|
||||
trace!("questions = {:?}", questions);
|
||||
trace!("options = {:?}", options);
|
||||
|
||||
let answers = (self.on_challenge)(questions, options);
|
||||
AuthResponse::Challenge { answers }
|
||||
}
|
||||
AuthRequest::Verify { kind, text } => {
|
||||
trace!("Received verify request");
|
||||
trace!("kind = {:?}", kind);
|
||||
trace!("text = {:?}", text);
|
||||
|
||||
let valid = (self.on_verify)(kind, text);
|
||||
AuthResponse::Verify { valid }
|
||||
}
|
||||
AuthRequest::Info { text } => {
|
||||
trace!("Received info request");
|
||||
trace!("text = {:?}", text);
|
||||
|
||||
(self.on_info)(text);
|
||||
return;
|
||||
}
|
||||
AuthRequest::Error { kind, text } => {
|
||||
trace!("Received error request");
|
||||
trace!("kind = {:?}", kind);
|
||||
trace!("text = {:?}", text);
|
||||
|
||||
(self.on_error)(kind, text);
|
||||
return;
|
||||
}
|
||||
},
|
||||
Err(x) => {
|
||||
error!("[Conn {}] {}", ctx.connection_id, x);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Serialize and encrypt the message before sending it back
|
||||
let encrypted_payload = match ctx.local_data.write().await.as_mut() {
|
||||
Some(codec) => {
|
||||
let mut encrypted_payload = BytesMut::new();
|
||||
|
||||
// Convert the response into bytes for us to send back
|
||||
match utils::serialize_to_vec(&response) {
|
||||
Ok(bytes) => match codec.encode(&bytes, &mut encrypted_payload) {
|
||||
Ok(_) => Ok(encrypted_payload.freeze().to_vec()),
|
||||
Err(x) => Err(x),
|
||||
},
|
||||
Err(x) => Err(x),
|
||||
}
|
||||
}
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Handshake must be performed first (server encrypt messaage)",
|
||||
)),
|
||||
};
|
||||
|
||||
match encrypted_payload {
|
||||
Ok(encrypted_payload) => {
|
||||
if let Err(x) = reply.send(Auth::Msg { encrypted_payload }).await {
|
||||
error!("[Conn {}] {}", ctx.connection_id, x);
|
||||
return;
|
||||
}
|
||||
}
|
||||
Err(x) => {
|
||||
error!("[Conn {}] {}", ctx.connection_id, x);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{
|
||||
IntoSplit, MpscListener, MpscTransport, Request, Response, ServerExt, ServerRef,
|
||||
TypedAsyncRead, TypedAsyncWrite,
|
||||
};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
const TIMEOUT_MILLIS: u64 = 100;
|
||||
|
||||
#[tokio::test]
|
||||
async fn should_not_reply_if_receive_encrypted_msg_without_handshake_first() {
|
||||
let (mut t, _) = spawn_auth_server(
|
||||
/* on_challenge */ |_, _| Vec::new(),
|
||||
/* on_verify */ |_, _| false,
|
||||
/* on_info */ |_| {},
|
||||
/* on_error */ |_, _| {},
|
||||
)
|
||||
.await
|
||||
.expect("Failed to spawn server");
|
||||
|
||||
// Send an encrypted message before establishing a handshake
|
||||
t.write(Request::new(Auth::Msg {
|
||||
encrypted_payload: Vec::new(),
|
||||
}))
|
||||
.await
|
||||
.expect("Failed to send request to server");
|
||||
|
||||
// Wait for a response, failing if we get one
|
||||
tokio::select! {
|
||||
x = t.read() => panic!("Unexpectedly resolved: {:?}", x),
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => {}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn should_reply_to_handshake_request_with_new_public_key_and_salt() {
|
||||
let (mut t, _) = spawn_auth_server(
|
||||
/* on_challenge */ |_, _| Vec::new(),
|
||||
/* on_verify */ |_, _| false,
|
||||
/* on_info */ |_| {},
|
||||
/* on_error */ |_, _| {},
|
||||
)
|
||||
.await
|
||||
.expect("Failed to spawn server");
|
||||
|
||||
// Send a handshake
|
||||
let handshake = Handshake::default();
|
||||
t.write(Request::new(Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
}))
|
||||
.await
|
||||
.expect("Failed to send request to server");
|
||||
|
||||
// Wait for a handshake response
|
||||
tokio::select! {
|
||||
x = t.read() => {
|
||||
let response = x.expect("Request failed").expect("Response missing");
|
||||
match response.payload {
|
||||
Auth::Handshake { .. } => {},
|
||||
Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
|
||||
}
|
||||
}
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => panic!("Ran out of time waiting on response"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn should_not_reply_if_receive_invalid_encrypted_msg() {
|
||||
let (mut t, _) = spawn_auth_server(
|
||||
/* on_challenge */ |_, _| Vec::new(),
|
||||
/* on_verify */ |_, _| false,
|
||||
/* on_info */ |_| {},
|
||||
/* on_error */ |_, _| {},
|
||||
)
|
||||
.await
|
||||
.expect("Failed to spawn server");
|
||||
|
||||
// Send a handshake
|
||||
let handshake = Handshake::default();
|
||||
t.write(Request::new(Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
}))
|
||||
.await
|
||||
.expect("Failed to send request to server");
|
||||
|
||||
// Complete handshake
|
||||
let key = match t.read().await.unwrap().unwrap().payload {
|
||||
Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(),
|
||||
Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
|
||||
};
|
||||
|
||||
// Send a bad chunk of data
|
||||
let _codec = XChaCha20Poly1305Codec::new(&key);
|
||||
t.write(Request::new(Auth::Msg {
|
||||
encrypted_payload: vec![1, 2, 3, 4],
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Wait for a response, failing if we get one
|
||||
tokio::select! {
|
||||
x = t.read() => panic!("Unexpectedly resolved: {:?}", x),
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => {}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn should_invoke_appropriate_function_when_receive_challenge_request_and_reply() {
|
||||
let (tx, mut rx) = mpsc::channel(1);
|
||||
let (mut t, _) = spawn_auth_server(
|
||||
/* on_challenge */
|
||||
move |questions, options| {
|
||||
tx.try_send((questions, options)).unwrap();
|
||||
vec!["answer1".to_string(), "answer2".to_string()]
|
||||
},
|
||||
/* on_verify */ |_, _| false,
|
||||
/* on_info */ |_| {},
|
||||
/* on_error */ |_, _| {},
|
||||
)
|
||||
.await
|
||||
.expect("Failed to spawn server");
|
||||
|
||||
// Send a handshake
|
||||
let handshake = Handshake::default();
|
||||
t.write(Request::new(Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
}))
|
||||
.await
|
||||
.expect("Failed to send request to server");
|
||||
|
||||
// Complete handshake
|
||||
let key = match t.read().await.unwrap().unwrap().payload {
|
||||
Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(),
|
||||
Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
|
||||
};
|
||||
|
||||
// Send an error request
|
||||
let mut codec = XChaCha20Poly1305Codec::new(&key);
|
||||
t.write(Request::new(Auth::Msg {
|
||||
encrypted_payload: serialize_and_encrypt(
|
||||
&mut codec,
|
||||
&AuthRequest::Challenge {
|
||||
questions: vec![
|
||||
AuthQuestion::new("question1".to_string()),
|
||||
AuthQuestion {
|
||||
text: "question2".to_string(),
|
||||
options: vec![("key".to_string(), "value".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
},
|
||||
],
|
||||
options: vec![("hello".to_string(), "world".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
},
|
||||
)
|
||||
.unwrap(),
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify that the handler was triggered
|
||||
let (questions, options) = rx.recv().await.expect("Channel closed unexpectedly");
|
||||
assert_eq!(
|
||||
questions,
|
||||
vec![
|
||||
AuthQuestion::new("question1".to_string()),
|
||||
AuthQuestion {
|
||||
text: "question2".to_string(),
|
||||
options: vec![("key".to_string(), "value".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
}
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
options,
|
||||
vec![("hello".to_string(), "world".to_string())]
|
||||
.into_iter()
|
||||
.collect()
|
||||
);
|
||||
|
||||
// Wait for a response and verify that it matches what we expect
|
||||
tokio::select! {
|
||||
x = t.read() => {
|
||||
let response = x.expect("Request failed").expect("Response missing");
|
||||
match response.payload {
|
||||
Auth::Handshake { .. } => panic!("Received unexpected handshake"),
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
|
||||
AuthResponse::Challenge { answers } =>
|
||||
assert_eq!(
|
||||
answers,
|
||||
vec!["answer1".to_string(), "answer2".to_string()]
|
||||
),
|
||||
_ => panic!("Got wrong response for verify"),
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => {}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn should_invoke_appropriate_function_when_receive_verify_request_and_reply() {
|
||||
let (tx, mut rx) = mpsc::channel(1);
|
||||
let (mut t, _) = spawn_auth_server(
|
||||
/* on_challenge */ |_, _| Vec::new(),
|
||||
/* on_verify */
|
||||
move |kind, text| {
|
||||
tx.try_send((kind, text)).unwrap();
|
||||
true
|
||||
},
|
||||
/* on_info */ |_| {},
|
||||
/* on_error */ |_, _| {},
|
||||
)
|
||||
.await
|
||||
.expect("Failed to spawn server");
|
||||
|
||||
// Send a handshake
|
||||
let handshake = Handshake::default();
|
||||
t.write(Request::new(Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
}))
|
||||
.await
|
||||
.expect("Failed to send request to server");
|
||||
|
||||
// Complete handshake
|
||||
let key = match t.read().await.unwrap().unwrap().payload {
|
||||
Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(),
|
||||
Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
|
||||
};
|
||||
|
||||
// Send an error request
|
||||
let mut codec = XChaCha20Poly1305Codec::new(&key);
|
||||
t.write(Request::new(Auth::Msg {
|
||||
encrypted_payload: serialize_and_encrypt(
|
||||
&mut codec,
|
||||
&AuthRequest::Verify {
|
||||
kind: AuthVerifyKind::Host,
|
||||
text: "some text".to_string(),
|
||||
},
|
||||
)
|
||||
.unwrap(),
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify that the handler was triggered
|
||||
let (kind, text) = rx.recv().await.expect("Channel closed unexpectedly");
|
||||
assert_eq!(kind, AuthVerifyKind::Host);
|
||||
assert_eq!(text, "some text");
|
||||
|
||||
// Wait for a response and verify that it matches what we expect
|
||||
tokio::select! {
|
||||
x = t.read() => {
|
||||
let response = x.expect("Request failed").expect("Response missing");
|
||||
match response.payload {
|
||||
Auth::Handshake { .. } => panic!("Received unexpected handshake"),
|
||||
Auth::Msg { encrypted_payload } => {
|
||||
match decrypt_and_deserialize(&mut codec, &encrypted_payload).unwrap() {
|
||||
AuthResponse::Verify { valid } =>
|
||||
assert!(valid, "Got verify, but valid was wrong"),
|
||||
_ => panic!("Got wrong response for verify"),
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => {}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn should_invoke_appropriate_function_when_receive_info_request() {
|
||||
let (tx, mut rx) = mpsc::channel(1);
|
||||
let (mut t, _) = spawn_auth_server(
|
||||
/* on_challenge */ |_, _| Vec::new(),
|
||||
/* on_verify */ |_, _| false,
|
||||
/* on_info */
|
||||
move |text| {
|
||||
tx.try_send(text).unwrap();
|
||||
},
|
||||
/* on_error */ |_, _| {},
|
||||
)
|
||||
.await
|
||||
.expect("Failed to spawn server");
|
||||
|
||||
// Send a handshake
|
||||
let handshake = Handshake::default();
|
||||
t.write(Request::new(Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
}))
|
||||
.await
|
||||
.expect("Failed to send request to server");
|
||||
|
||||
// Complete handshake
|
||||
let key = match t.read().await.unwrap().unwrap().payload {
|
||||
Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(),
|
||||
Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
|
||||
};
|
||||
|
||||
// Send an error request
|
||||
let mut codec = XChaCha20Poly1305Codec::new(&key);
|
||||
t.write(Request::new(Auth::Msg {
|
||||
encrypted_payload: serialize_and_encrypt(
|
||||
&mut codec,
|
||||
&AuthRequest::Info {
|
||||
text: "some text".to_string(),
|
||||
},
|
||||
)
|
||||
.unwrap(),
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify that the handler was triggered
|
||||
let text = rx.recv().await.expect("Channel closed unexpectedly");
|
||||
assert_eq!(text, "some text");
|
||||
|
||||
// Wait for a response, failing if we get one
|
||||
tokio::select! {
|
||||
x = t.read() => panic!("Unexpectedly resolved: {:?}", x),
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => {}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn should_invoke_appropriate_function_when_receive_error_request() {
|
||||
let (tx, mut rx) = mpsc::channel(1);
|
||||
let (mut t, _) = spawn_auth_server(
|
||||
/* on_challenge */ |_, _| Vec::new(),
|
||||
/* on_verify */ |_, _| false,
|
||||
/* on_info */ |_| {},
|
||||
/* on_error */
|
||||
move |kind, text| {
|
||||
tx.try_send((kind, text)).unwrap();
|
||||
},
|
||||
)
|
||||
.await
|
||||
.expect("Failed to spawn server");
|
||||
|
||||
// Send a handshake
|
||||
let handshake = Handshake::default();
|
||||
t.write(Request::new(Auth::Handshake {
|
||||
public_key: handshake.pk_bytes(),
|
||||
salt: *handshake.salt(),
|
||||
}))
|
||||
.await
|
||||
.expect("Failed to send request to server");
|
||||
|
||||
// Complete handshake
|
||||
let key = match t.read().await.unwrap().unwrap().payload {
|
||||
Auth::Handshake { public_key, salt } => handshake.handshake(public_key, salt).unwrap(),
|
||||
Auth::Msg { .. } => panic!("Received unexpected encryped message during handshake"),
|
||||
};
|
||||
|
||||
// Send an error request
|
||||
let mut codec = XChaCha20Poly1305Codec::new(&key);
|
||||
t.write(Request::new(Auth::Msg {
|
||||
encrypted_payload: serialize_and_encrypt(
|
||||
&mut codec,
|
||||
&AuthRequest::Error {
|
||||
kind: AuthErrorKind::FailedChallenge,
|
||||
text: "some text".to_string(),
|
||||
},
|
||||
)
|
||||
.unwrap(),
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify that the handler was triggered
|
||||
let (kind, text) = rx.recv().await.expect("Channel closed unexpectedly");
|
||||
assert_eq!(kind, AuthErrorKind::FailedChallenge);
|
||||
assert_eq!(text, "some text");
|
||||
|
||||
// Wait for a response, failing if we get one
|
||||
tokio::select! {
|
||||
x = t.read() => panic!("Unexpectedly resolved: {:?}", x),
|
||||
_ = wait_ms(TIMEOUT_MILLIS) => {}
|
||||
}
|
||||
}
|
||||
|
||||
async fn wait_ms(ms: u64) {
|
||||
use std::time::Duration;
|
||||
tokio::time::sleep(Duration::from_millis(ms)).await;
|
||||
}
|
||||
|
||||
fn serialize_and_encrypt(
|
||||
codec: &mut XChaCha20Poly1305Codec,
|
||||
payload: &AuthRequest,
|
||||
) -> io::Result<Vec<u8>> {
|
||||
let mut encryped_payload = BytesMut::new();
|
||||
let payload = utils::serialize_to_vec(payload)?;
|
||||
codec.encode(&payload, &mut encryped_payload)?;
|
||||
Ok(encryped_payload.freeze().to_vec())
|
||||
}
|
||||
|
||||
fn decrypt_and_deserialize(
|
||||
codec: &mut XChaCha20Poly1305Codec,
|
||||
payload: &[u8],
|
||||
) -> io::Result<AuthResponse> {
|
||||
let mut payload = BytesMut::from(payload);
|
||||
match codec.decode(&mut payload)? {
|
||||
Some(payload) => utils::deserialize_from_slice::<AuthResponse>(&payload),
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"Incomplete message received",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn spawn_auth_server<ChallengeFn, VerifyFn, InfoFn, ErrorFn>(
|
||||
on_challenge: ChallengeFn,
|
||||
on_verify: VerifyFn,
|
||||
on_info: InfoFn,
|
||||
on_error: ErrorFn,
|
||||
) -> io::Result<(
|
||||
MpscTransport<Request<Auth>, Response<Auth>>,
|
||||
Box<dyn ServerRef>,
|
||||
)>
|
||||
where
|
||||
ChallengeFn:
|
||||
Fn(Vec<AuthQuestion>, HashMap<String, String>) -> Vec<String> + Send + Sync + 'static,
|
||||
VerifyFn: Fn(AuthVerifyKind, String) -> bool + Send + Sync + 'static,
|
||||
InfoFn: Fn(String) + Send + Sync + 'static,
|
||||
ErrorFn: Fn(AuthErrorKind, String) + Send + Sync + 'static,
|
||||
{
|
||||
let server = AuthServer {
|
||||
on_challenge,
|
||||
on_verify,
|
||||
on_info,
|
||||
on_error,
|
||||
};
|
||||
|
||||
// Create a test listener where we will forward a connection
|
||||
let (tx, listener) = MpscListener::channel(100);
|
||||
|
||||
// Make bounded transport pair and send off one of them to act as our connection
|
||||
let (transport, connection) = MpscTransport::<Request<Auth>, Response<Auth>>::pair(100);
|
||||
tx.send(connection.into_split())
|
||||
.await
|
||||
.expect("Failed to feed listener a connection");
|
||||
|
||||
let server = server.start(listener)?;
|
||||
Ok((transport, server))
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,142 @@
|
||||
mod tcp;
|
||||
pub use tcp::*;
|
||||
|
||||
#[cfg(unix)]
|
||||
mod unix;
|
||||
|
||||
#[cfg(unix)]
|
||||
pub use unix::*;
|
||||
|
||||
#[cfg(windows)]
|
||||
mod windows;
|
||||
|
||||
#[cfg(windows)]
|
||||
pub use windows::*;
|
||||
|
||||
use crate::client::{Client, ReconnectStrategy, UntypedClient};
|
||||
use crate::common::{authentication::AuthHandler, Connection, Transport};
|
||||
use async_trait::async_trait;
|
||||
use std::{convert, io, time::Duration};
|
||||
|
||||
/// Interface that performs the connection to produce a [`Transport`] for use by the [`Client`].
|
||||
#[async_trait]
|
||||
pub trait Connector {
|
||||
/// Type of transport produced by the connection.
|
||||
type Transport: Transport + 'static;
|
||||
|
||||
async fn connect(self) -> io::Result<Self::Transport>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T: Transport + 'static> Connector for T {
|
||||
type Transport = T;
|
||||
|
||||
async fn connect(self) -> io::Result<Self::Transport> {
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for a [`Client`] or [`UntypedClient`].
|
||||
pub struct ClientBuilder<H, C> {
|
||||
auth_handler: H,
|
||||
connector: C,
|
||||
reconnect_strategy: ReconnectStrategy,
|
||||
timeout: Option<Duration>,
|
||||
}
|
||||
|
||||
impl<H, C> ClientBuilder<H, C> {
|
||||
pub fn auth_handler<U>(self, auth_handler: U) -> ClientBuilder<U, C> {
|
||||
ClientBuilder {
|
||||
auth_handler,
|
||||
connector: self.connector,
|
||||
reconnect_strategy: self.reconnect_strategy,
|
||||
timeout: self.timeout,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn connector<U>(self, connector: U) -> ClientBuilder<H, U> {
|
||||
ClientBuilder {
|
||||
auth_handler: self.auth_handler,
|
||||
connector,
|
||||
reconnect_strategy: self.reconnect_strategy,
|
||||
timeout: self.timeout,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reconnect_strategy(self, reconnect_strategy: ReconnectStrategy) -> ClientBuilder<H, C> {
|
||||
ClientBuilder {
|
||||
auth_handler: self.auth_handler,
|
||||
connector: self.connector,
|
||||
reconnect_strategy,
|
||||
timeout: self.timeout,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
|
||||
Self {
|
||||
auth_handler: self.auth_handler,
|
||||
connector: self.connector,
|
||||
reconnect_strategy: self.reconnect_strategy,
|
||||
timeout: timeout.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientBuilder<(), ()> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
auth_handler: (),
|
||||
reconnect_strategy: ReconnectStrategy::default(),
|
||||
connector: (),
|
||||
timeout: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ClientBuilder<(), ()> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<H, C> ClientBuilder<H, C>
|
||||
where
|
||||
H: AuthHandler + Send,
|
||||
C: Connector,
|
||||
{
|
||||
/// Establishes a connection with a remote server using the configured [`Transport`]
|
||||
/// and other settings, returning a new [`UntypedClient`] instance once the connection
|
||||
/// is fully established and authenticated.
|
||||
pub async fn connect_untyped(self) -> io::Result<UntypedClient> {
|
||||
let auth_handler = self.auth_handler;
|
||||
let retry_strategy = self.reconnect_strategy;
|
||||
let timeout = self.timeout;
|
||||
|
||||
let f = async move {
|
||||
let transport = match timeout {
|
||||
Some(duration) => tokio::time::timeout(duration, self.connector.connect())
|
||||
.await
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
|
||||
.and_then(convert::identity)?,
|
||||
None => self.connector.connect().await?,
|
||||
};
|
||||
let connection = Connection::client(transport, auth_handler).await?;
|
||||
Ok(UntypedClient::spawn(connection, retry_strategy))
|
||||
};
|
||||
|
||||
match timeout {
|
||||
Some(duration) => tokio::time::timeout(duration, f)
|
||||
.await
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
|
||||
.and_then(convert::identity),
|
||||
None => f.await,
|
||||
}
|
||||
}
|
||||
|
||||
/// Establishes a connection with a remote server using the configured [`Transport`] and other
|
||||
/// settings, returning a new [`Client`] instance once the connection is fully established and
|
||||
/// authenticated.
|
||||
pub async fn connect<T, U>(self) -> io::Result<Client<T, U>> {
|
||||
Ok(self.connect_untyped().await?.into_typed_client())
|
||||
}
|
||||
}
|
@ -0,0 +1,31 @@
|
||||
use super::Connector;
|
||||
use crate::common::TcpTransport;
|
||||
use async_trait::async_trait;
|
||||
use std::io;
|
||||
use tokio::net::ToSocketAddrs;
|
||||
|
||||
/// Implementation of [`Connector`] to support connecting via TCP.
|
||||
pub struct TcpConnector<T> {
|
||||
addr: T,
|
||||
}
|
||||
|
||||
impl<T> TcpConnector<T> {
|
||||
pub fn new(addr: T) -> Self {
|
||||
Self { addr }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<T> for TcpConnector<T> {
|
||||
fn from(addr: T) -> Self {
|
||||
Self::new(addr)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T: ToSocketAddrs + Send> Connector for TcpConnector<T> {
|
||||
type Transport = TcpTransport;
|
||||
|
||||
async fn connect(self) -> io::Result<Self::Transport> {
|
||||
TcpTransport::connect(self.addr).await
|
||||
}
|
||||
}
|
@ -0,0 +1,30 @@
|
||||
use super::Connector;
|
||||
use crate::common::UnixSocketTransport;
|
||||
use async_trait::async_trait;
|
||||
use std::{io, path::PathBuf};
|
||||
|
||||
/// Implementation of [`Connector`] to support connecting via a Unix socket.
|
||||
pub struct UnixSocketConnector {
|
||||
path: PathBuf,
|
||||
}
|
||||
|
||||
impl UnixSocketConnector {
|
||||
pub fn new(path: impl Into<PathBuf>) -> Self {
|
||||
Self { path: path.into() }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Into<PathBuf>> From<T> for UnixSocketConnector {
|
||||
fn from(path: T) -> Self {
|
||||
Self::new(path)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Connector for UnixSocketConnector {
|
||||
type Transport = UnixSocketTransport;
|
||||
|
||||
async fn connect(self) -> io::Result<Self::Transport> {
|
||||
UnixSocketTransport::connect(self.path).await
|
||||
}
|
||||
}
|
@ -0,0 +1,50 @@
|
||||
use super::Connector;
|
||||
use crate::common::WindowsPipeTransport;
|
||||
use async_trait::async_trait;
|
||||
use std::ffi::OsString;
|
||||
use std::io;
|
||||
|
||||
/// Implementation of [`Connector`] to support connecting via a Windows named pipe.
|
||||
pub struct WindowsPipeConnector {
|
||||
addr: OsString,
|
||||
pub(crate) local: bool,
|
||||
}
|
||||
|
||||
impl WindowsPipeConnector {
|
||||
/// Creates a new connector for a non-local pipe using the given `addr`.
|
||||
pub fn new(addr: impl Into<OsString>) -> Self {
|
||||
Self {
|
||||
addr: addr.into(),
|
||||
local: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new connector for a local pipe using the given `name`.
|
||||
pub fn local(name: impl Into<OsString>) -> Self {
|
||||
Self {
|
||||
addr: name.into(),
|
||||
local: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Into<OsString>> From<T> for WindowsPipeConnector {
|
||||
fn from(addr: T) -> Self {
|
||||
Self::new(addr)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Connector for WindowsPipeConnector {
|
||||
type Transport = WindowsPipeTransport;
|
||||
|
||||
async fn connect(self) -> io::Result<Self::Transport> {
|
||||
if self.local {
|
||||
let mut full_addr = OsString::from(r"\\.\pipe\");
|
||||
full_addr.push(self.addr);
|
||||
WindowsPipeTransport::connect(full_addr).await
|
||||
} else {
|
||||
WindowsPipeTransport::connect(self.addr).await
|
||||
}
|
||||
}
|
||||
}
|
@ -1,49 +0,0 @@
|
||||
use crate::{Client, Codec, FramedTransport, TcpTransport};
|
||||
use async_trait::async_trait;
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use std::{convert, net::SocketAddr};
|
||||
use tokio::{io, time::Duration};
|
||||
|
||||
#[async_trait]
|
||||
pub trait TcpClientExt<T, U>
|
||||
where
|
||||
T: Serialize + Send + Sync,
|
||||
U: DeserializeOwned + Send + Sync,
|
||||
{
|
||||
/// Connect to a remote TCP server using the provided information
|
||||
async fn connect<C>(addr: SocketAddr, codec: C) -> io::Result<Client<T, U>>
|
||||
where
|
||||
C: Codec + Send + 'static;
|
||||
|
||||
/// Connect to a remote TCP server, timing out after duration has passed
|
||||
async fn connect_timeout<C>(
|
||||
addr: SocketAddr,
|
||||
codec: C,
|
||||
duration: Duration,
|
||||
) -> io::Result<Client<T, U>>
|
||||
where
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
tokio::time::timeout(duration, Self::connect(addr, codec))
|
||||
.await
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
|
||||
.and_then(convert::identity)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T, U> TcpClientExt<T, U> for Client<T, U>
|
||||
where
|
||||
T: Send + Sync + Serialize + 'static,
|
||||
U: Send + Sync + DeserializeOwned + 'static,
|
||||
{
|
||||
/// Connect to a remote TCP server using the provided information
|
||||
async fn connect<C>(addr: SocketAddr, codec: C) -> io::Result<Client<T, U>>
|
||||
where
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
let transport = TcpTransport::connect(addr).await?;
|
||||
let transport = FramedTransport::new(transport, codec);
|
||||
Self::from_framed_transport(transport)
|
||||
}
|
||||
}
|
@ -1,54 +0,0 @@
|
||||
use crate::{Client, Codec, FramedTransport, IntoSplit, UnixSocketTransport};
|
||||
use async_trait::async_trait;
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use std::{convert, path::Path};
|
||||
use tokio::{io, time::Duration};
|
||||
|
||||
#[async_trait]
|
||||
pub trait UnixSocketClientExt<T, U>
|
||||
where
|
||||
T: Serialize + Send + Sync,
|
||||
U: DeserializeOwned + Send + Sync,
|
||||
{
|
||||
/// Connect to a proxy unix socket
|
||||
async fn connect<P, C>(path: P, codec: C) -> io::Result<Client<T, U>>
|
||||
where
|
||||
P: AsRef<Path> + Send,
|
||||
C: Codec + Send + 'static;
|
||||
|
||||
/// Connect to a proxy unix socket, timing out after duration has passed
|
||||
async fn connect_timeout<P, C>(
|
||||
path: P,
|
||||
codec: C,
|
||||
duration: Duration,
|
||||
) -> io::Result<Client<T, U>>
|
||||
where
|
||||
P: AsRef<Path> + Send,
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
tokio::time::timeout(duration, Self::connect(path, codec))
|
||||
.await
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
|
||||
.and_then(convert::identity)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T, U> UnixSocketClientExt<T, U> for Client<T, U>
|
||||
where
|
||||
T: Send + Sync + Serialize + 'static,
|
||||
U: Send + Sync + DeserializeOwned + 'static,
|
||||
{
|
||||
/// Connect to a proxy unix socket
|
||||
async fn connect<P, C>(path: P, codec: C) -> io::Result<Client<T, U>>
|
||||
where
|
||||
P: AsRef<Path> + Send,
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
let p = path.as_ref();
|
||||
let transport = UnixSocketTransport::connect(p).await?;
|
||||
let transport = FramedTransport::new(transport, codec);
|
||||
let (writer, reader) = transport.into_split();
|
||||
Ok(Client::new(writer, reader)?)
|
||||
}
|
||||
}
|
@ -1,86 +0,0 @@
|
||||
use crate::{Client, Codec, FramedTransport, IntoSplit, WindowsPipeTransport};
|
||||
use async_trait::async_trait;
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use std::{
|
||||
convert,
|
||||
ffi::{OsStr, OsString},
|
||||
};
|
||||
use tokio::{io, time::Duration};
|
||||
|
||||
#[async_trait]
|
||||
pub trait WindowsPipeClientExt<T, U>
|
||||
where
|
||||
T: Serialize + Send + Sync,
|
||||
U: DeserializeOwned + Send + Sync,
|
||||
{
|
||||
/// Connect to a server listening on a Windows pipe at the specified address
|
||||
/// using the given codec
|
||||
async fn connect<A, C>(addr: A, codec: C) -> io::Result<Client<T, U>>
|
||||
where
|
||||
A: AsRef<OsStr> + Send,
|
||||
C: Codec + Send + 'static;
|
||||
|
||||
/// Connect to a server listening on a Windows pipe at the specified address
|
||||
/// via `\\.\pipe\{name}` using the given codec
|
||||
async fn connect_local<N, C>(name: N, codec: C) -> io::Result<Client<T, U>>
|
||||
where
|
||||
N: AsRef<OsStr> + Send,
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
let mut addr = OsString::from(r"\\.\pipe\");
|
||||
addr.push(name.as_ref());
|
||||
Self::connect(addr, codec).await
|
||||
}
|
||||
|
||||
/// Connect to a server listening on a Windows pipe at the specified address
|
||||
/// using the given codec, timing out after duration has passed
|
||||
async fn connect_timeout<A, C>(
|
||||
addr: A,
|
||||
codec: C,
|
||||
duration: Duration,
|
||||
) -> io::Result<Client<T, U>>
|
||||
where
|
||||
A: AsRef<OsStr> + Send,
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
tokio::time::timeout(duration, Self::connect(addr, codec))
|
||||
.await
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
|
||||
.and_then(convert::identity)
|
||||
}
|
||||
|
||||
/// Connect to a server listening on a Windows pipe at the specified address
|
||||
/// via `\\.\pipe\{name}` using the given codec, timing out after duration has passed
|
||||
async fn connect_local_timeout<N, C>(
|
||||
name: N,
|
||||
codec: C,
|
||||
duration: Duration,
|
||||
) -> io::Result<Client<T, U>>
|
||||
where
|
||||
N: AsRef<OsStr> + Send,
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
let mut addr = OsString::from(r"\\.\pipe\");
|
||||
addr.push(name.as_ref());
|
||||
Self::connect_timeout(addr, codec, duration).await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T, U> WindowsPipeClientExt<T, U> for Client<T, U>
|
||||
where
|
||||
T: Send + Sync + Serialize + 'static,
|
||||
U: Send + Sync + DeserializeOwned + 'static,
|
||||
{
|
||||
async fn connect<A, C>(addr: A, codec: C) -> io::Result<Client<T, U>>
|
||||
where
|
||||
A: AsRef<OsStr> + Send,
|
||||
C: Codec + Send + 'static,
|
||||
{
|
||||
let a = addr.as_ref();
|
||||
let transport = WindowsPipeTransport::connect(a).await?;
|
||||
let transport = FramedTransport::new(transport, codec);
|
||||
let (writer, reader) = transport.into_split();
|
||||
Ok(Client::new(writer, reader)?)
|
||||
}
|
||||
}
|
@ -0,0 +1,208 @@
|
||||
use super::Reconnectable;
|
||||
use std::io;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Represents the strategy to apply when attempting to reconnect the client to the server.
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum ReconnectStrategy {
|
||||
/// A retry strategy that will fail immediately if a reconnect is attempted.
|
||||
Fail,
|
||||
|
||||
/// A retry strategy driven by exponential back-off.
|
||||
ExponentialBackoff {
|
||||
/// Represents the initial time to wait between reconnect attempts.
|
||||
base: Duration,
|
||||
|
||||
/// Factor to use when modifying the retry time, used as a multiplier.
|
||||
factor: f64,
|
||||
|
||||
/// Represents the maximum duration to wait between attempts. None indicates no limit.
|
||||
max_duration: Option<Duration>,
|
||||
|
||||
/// Represents the maximum attempts to retry before failing. None indicates no limit.
|
||||
max_retries: Option<usize>,
|
||||
|
||||
/// Represents the maximum time to wait for a reconnect attempt. None indicates no limit.
|
||||
timeout: Option<Duration>,
|
||||
},
|
||||
|
||||
/// A retry strategy driven by the fibonacci series.
|
||||
FibonacciBackoff {
|
||||
/// Represents the initial time to wait between reconnect attempts.
|
||||
base: Duration,
|
||||
|
||||
/// Represents the maximum duration to wait between attempts. None indicates no limit.
|
||||
max_duration: Option<Duration>,
|
||||
|
||||
/// Represents the maximum attempts to retry before failing. None indicates no limit.
|
||||
max_retries: Option<usize>,
|
||||
|
||||
/// Represents the maximum time to wait for a reconnect attempt. None indicates no limit.
|
||||
timeout: Option<Duration>,
|
||||
},
|
||||
|
||||
/// A retry strategy driven by a fixed interval.
|
||||
FixedInterval {
|
||||
/// Represents the time between reconnect attempts.
|
||||
interval: Duration,
|
||||
|
||||
/// Represents the maximum attempts to retry before failing. None indicates no limit.
|
||||
max_retries: Option<usize>,
|
||||
|
||||
/// Represents the maximum time to wait for a reconnect attempt. None indicates no limit.
|
||||
timeout: Option<Duration>,
|
||||
},
|
||||
}
|
||||
|
||||
impl Default for ReconnectStrategy {
|
||||
/// Creates a reconnect strategy that will immediately fail.
|
||||
fn default() -> Self {
|
||||
Self::Fail
|
||||
}
|
||||
}
|
||||
|
||||
impl ReconnectStrategy {
|
||||
pub async fn reconnect<T: Reconnectable>(&mut self, reconnectable: &mut T) -> io::Result<()> {
|
||||
// If our strategy is to immediately fail, do so
|
||||
if self.is_fail() {
|
||||
return Err(io::Error::from(io::ErrorKind::ConnectionAborted));
|
||||
}
|
||||
|
||||
// Keep track of last sleep length for use in adjustment
|
||||
let mut previous_sleep = None;
|
||||
let mut current_sleep = self.initial_sleep_duration();
|
||||
|
||||
// Keep track of remaining retries
|
||||
let mut retries_remaining = self.max_retries();
|
||||
|
||||
// Get timeout if strategy will employ one
|
||||
let timeout = self.timeout();
|
||||
|
||||
// Get maximum allowed duration between attempts
|
||||
let max_duration = self.max_duration();
|
||||
|
||||
// Continue trying to reconnect while we have more tries remaining, otherwise
|
||||
// we will return the last error encountered
|
||||
let mut result = Ok(());
|
||||
|
||||
while retries_remaining.is_none() || retries_remaining > Some(0) {
|
||||
// Perform reconnect attempt
|
||||
result = match timeout {
|
||||
Some(timeout) => {
|
||||
match tokio::time::timeout(timeout, reconnectable.reconnect()).await {
|
||||
Ok(x) => x,
|
||||
Err(x) => Err(x.into()),
|
||||
}
|
||||
}
|
||||
None => reconnectable.reconnect().await,
|
||||
};
|
||||
|
||||
// If reconnect was successful, we're done and we can exit
|
||||
if result.is_ok() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Decrement remaining retries if we have a limit
|
||||
if let Some(remaining) = retries_remaining.as_mut() {
|
||||
if *remaining > 0 {
|
||||
*remaining -= 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Sleep before making next attempt
|
||||
tokio::time::sleep(current_sleep).await;
|
||||
|
||||
// Update our sleep duration
|
||||
let next_sleep = self.adjust_sleep(previous_sleep, current_sleep);
|
||||
previous_sleep = Some(current_sleep);
|
||||
current_sleep = if let Some(duration) = max_duration {
|
||||
std::cmp::min(next_sleep, duration)
|
||||
} else {
|
||||
next_sleep
|
||||
};
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Returns true if this strategy is the fail variant.
|
||||
pub fn is_fail(&self) -> bool {
|
||||
matches!(self, Self::Fail)
|
||||
}
|
||||
|
||||
/// Returns true if this strategy is the exponential backoff variant.
|
||||
pub fn is_exponential_backoff(&self) -> bool {
|
||||
matches!(self, Self::ExponentialBackoff { .. })
|
||||
}
|
||||
|
||||
/// Returns true if this strategy is the fibonacci backoff variant.
|
||||
pub fn is_fibonacci_backoff(&self) -> bool {
|
||||
matches!(self, Self::FibonacciBackoff { .. })
|
||||
}
|
||||
|
||||
/// Returns true if this strategy is the fixed interval variant.
|
||||
pub fn is_fixed_interval(&self) -> bool {
|
||||
matches!(self, Self::FixedInterval { .. })
|
||||
}
|
||||
|
||||
/// Returns the maximum duration between reconnect attempts, or None if there is no limit.
|
||||
pub fn max_duration(&self) -> Option<Duration> {
|
||||
match self {
|
||||
ReconnectStrategy::Fail => None,
|
||||
ReconnectStrategy::ExponentialBackoff { max_duration, .. } => *max_duration,
|
||||
ReconnectStrategy::FibonacciBackoff { max_duration, .. } => *max_duration,
|
||||
ReconnectStrategy::FixedInterval { .. } => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the maximum reconnect attempts the strategy will perform, or None if will attempt
|
||||
/// forever.
|
||||
pub fn max_retries(&self) -> Option<usize> {
|
||||
match self {
|
||||
ReconnectStrategy::Fail => None,
|
||||
ReconnectStrategy::ExponentialBackoff { max_retries, .. } => *max_retries,
|
||||
ReconnectStrategy::FibonacciBackoff { max_retries, .. } => *max_retries,
|
||||
ReconnectStrategy::FixedInterval { max_retries, .. } => *max_retries,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the timeout per reconnect attempt that is associated with the strategy.
|
||||
pub fn timeout(&self) -> Option<Duration> {
|
||||
match self {
|
||||
ReconnectStrategy::Fail => None,
|
||||
ReconnectStrategy::ExponentialBackoff { timeout, .. } => *timeout,
|
||||
ReconnectStrategy::FibonacciBackoff { timeout, .. } => *timeout,
|
||||
ReconnectStrategy::FixedInterval { timeout, .. } => *timeout,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the initial duration to sleep.
|
||||
fn initial_sleep_duration(&self) -> Duration {
|
||||
match self {
|
||||
ReconnectStrategy::Fail => Duration::new(0, 0),
|
||||
ReconnectStrategy::ExponentialBackoff { base, .. } => *base,
|
||||
ReconnectStrategy::FibonacciBackoff { base, .. } => *base,
|
||||
ReconnectStrategy::FixedInterval { interval, .. } => *interval,
|
||||
}
|
||||
}
|
||||
|
||||
/// Adjusts next sleep duration based on the strategy.
|
||||
fn adjust_sleep(&self, prev: Option<Duration>, curr: Duration) -> Duration {
|
||||
match self {
|
||||
ReconnectStrategy::Fail => Duration::new(0, 0),
|
||||
ReconnectStrategy::ExponentialBackoff { factor, .. } => {
|
||||
let next_millis = (curr.as_millis() as f64) * factor;
|
||||
Duration::from_millis(if next_millis > (std::u64::MAX as f64) {
|
||||
std::u64::MAX
|
||||
} else {
|
||||
next_millis as u64
|
||||
})
|
||||
}
|
||||
ReconnectStrategy::FibonacciBackoff { .. } => {
|
||||
let prev = prev.unwrap_or_else(|| Duration::new(0, 0));
|
||||
prev.checked_add(curr).unwrap_or(Duration::MAX)
|
||||
}
|
||||
ReconnectStrategy::FixedInterval { .. } => curr,
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,36 @@
|
||||
use async_trait::async_trait;
|
||||
use dyn_clone::DynClone;
|
||||
use std::io;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
|
||||
/// Interface representing functionality to shut down an active client.
|
||||
#[async_trait]
|
||||
pub trait Shutdown: DynClone + Send + Sync {
|
||||
/// Attempts to shutdown the client.
|
||||
async fn shutdown(&self) -> io::Result<()>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Shutdown for mpsc::Sender<oneshot::Sender<io::Result<()>>> {
|
||||
async fn shutdown(&self) -> io::Result<()> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
match self.send(tx).await {
|
||||
Ok(_) => match rx.await {
|
||||
Ok(x) => x,
|
||||
Err(_) => Err(already_shutdown()),
|
||||
},
|
||||
Err(_) => Err(already_shutdown()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn already_shutdown() -> io::Error {
|
||||
io::Error::new(io::ErrorKind::Other, "Client already shutdown")
|
||||
}
|
||||
|
||||
impl Clone for Box<dyn Shutdown> {
|
||||
fn clone(&self) -> Self {
|
||||
dyn_clone::clone_box(&**self)
|
||||
}
|
||||
}
|
@ -1,38 +0,0 @@
|
||||
use bytes::BytesMut;
|
||||
use std::io;
|
||||
use tokio_util::codec::{Decoder, Encoder};
|
||||
|
||||
/// Represents abstraction of a codec that implements specific encoder and decoder for distant
|
||||
pub trait Codec:
|
||||
for<'a> Encoder<&'a [u8], Error = io::Error> + Decoder<Item = Vec<u8>, Error = io::Error> + Clone
|
||||
{
|
||||
fn encode(&mut self, item: &[u8], dst: &mut BytesMut) -> io::Result<()>;
|
||||
fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<Vec<u8>>>;
|
||||
}
|
||||
|
||||
macro_rules! impl_traits_for_codec {
|
||||
($type:ident) => {
|
||||
impl<'a> tokio_util::codec::Encoder<&'a [u8]> for $type {
|
||||
type Error = io::Error;
|
||||
|
||||
fn encode(&mut self, item: &'a [u8], dst: &mut BytesMut) -> Result<(), Self::Error> {
|
||||
Codec::encode(self, item, dst)
|
||||
}
|
||||
}
|
||||
|
||||
impl tokio_util::codec::Decoder for $type {
|
||||
type Item = Vec<u8>;
|
||||
type Error = io::Error;
|
||||
|
||||
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||
Codec::decode(self, src)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
mod plain;
|
||||
pub use plain::PlainCodec;
|
||||
|
||||
mod xchacha20poly1305;
|
||||
pub use xchacha20poly1305::XChaCha20Poly1305Codec;
|
@ -1,193 +0,0 @@
|
||||
use crate::Codec;
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use std::convert::TryInto;
|
||||
use tokio::io;
|
||||
|
||||
/// Total bytes to use as the len field denoting a frame's size
|
||||
const LEN_SIZE: usize = 8;
|
||||
|
||||
/// Represents a codec that just ships messages back and forth with no encryption or authentication
|
||||
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
|
||||
pub struct PlainCodec;
|
||||
impl_traits_for_codec!(PlainCodec);
|
||||
|
||||
impl PlainCodec {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
}
|
||||
|
||||
impl Codec for PlainCodec {
|
||||
fn encode(&mut self, item: &[u8], dst: &mut BytesMut) -> io::Result<()> {
|
||||
// Validate that we can fit the message plus nonce +
|
||||
if item.is_empty() {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"Empty item provided",
|
||||
));
|
||||
}
|
||||
|
||||
dst.reserve(8 + item.len());
|
||||
|
||||
// Add data in form of {LEN}{ITEM}
|
||||
dst.put_u64((item.len()) as u64);
|
||||
dst.put_slice(item);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<Vec<u8>>> {
|
||||
// First, check if we have more data than just our frame's message length
|
||||
if src.len() <= LEN_SIZE {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Second, retrieve total size of our frame's message
|
||||
let msg_len = u64::from_be_bytes(src[..LEN_SIZE].try_into().unwrap()) as usize;
|
||||
if msg_len == 0 {
|
||||
// Ensure we advance to remove the frame
|
||||
src.advance(LEN_SIZE);
|
||||
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"Frame's msg cannot have length of 0",
|
||||
));
|
||||
}
|
||||
|
||||
// Third, check if we have all data for our frame; if not, exit early
|
||||
if src.len() < msg_len + LEN_SIZE {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Fourth, get and return our item
|
||||
let item = src[LEN_SIZE..(LEN_SIZE + msg_len)].to_vec();
|
||||
|
||||
// Fifth, advance so frame is no longer kept around
|
||||
src.advance(LEN_SIZE + msg_len);
|
||||
|
||||
Ok(Some(item))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn encode_should_fail_when_item_is_zero_bytes() {
|
||||
let mut codec = PlainCodec::new();
|
||||
|
||||
let mut buf = BytesMut::new();
|
||||
let result = codec.encode(&[], &mut buf);
|
||||
|
||||
match result {
|
||||
Err(x) if x.kind() == io::ErrorKind::InvalidInput => {}
|
||||
x => panic!("Unexpected result: {:?}", x),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encode_should_build_a_frame_containing_a_length_and_item() {
|
||||
let mut codec = PlainCodec::new();
|
||||
|
||||
let mut buf = BytesMut::new();
|
||||
codec
|
||||
.encode(b"hello, world", &mut buf)
|
||||
.expect("Failed to encode");
|
||||
|
||||
let len = buf.get_u64() as usize;
|
||||
assert_eq!(len, 12, "Wrong length encoded");
|
||||
assert_eq!(buf.as_ref(), b"hello, world");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_should_return_none_if_data_smaller_than_or_equal_to_item_length_field() {
|
||||
let mut codec = PlainCodec::new();
|
||||
|
||||
let mut buf = BytesMut::new();
|
||||
buf.put_bytes(0, LEN_SIZE);
|
||||
|
||||
let result = codec.decode(&mut buf);
|
||||
assert!(
|
||||
matches!(result, Ok(None)),
|
||||
"Unexpected result: {:?}",
|
||||
result
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_should_return_none_if_not_enough_data_for_frame() {
|
||||
let mut codec = PlainCodec::new();
|
||||
|
||||
let mut buf = BytesMut::new();
|
||||
buf.put_u64(0);
|
||||
|
||||
let result = codec.decode(&mut buf);
|
||||
assert!(
|
||||
matches!(result, Ok(None)),
|
||||
"Unexpected result: {:?}",
|
||||
result
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_should_fail_if_encoded_item_length_is_zero() {
|
||||
let mut codec = PlainCodec::new();
|
||||
|
||||
let mut buf = BytesMut::new();
|
||||
buf.put_u64(0);
|
||||
buf.put_u8(255);
|
||||
|
||||
let result = codec.decode(&mut buf);
|
||||
match result {
|
||||
Err(x) if x.kind() == io::ErrorKind::InvalidData => {}
|
||||
x => panic!("Unexpected result: {:?}", x),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_should_advance_src_by_frame_size_even_if_item_length_is_zero() {
|
||||
let mut codec = PlainCodec::new();
|
||||
|
||||
let mut buf = BytesMut::new();
|
||||
buf.put_u64(0);
|
||||
buf.put_bytes(0, 3);
|
||||
|
||||
assert!(
|
||||
codec.decode(&mut buf).is_err(),
|
||||
"Decode unexpectedly succeeded"
|
||||
);
|
||||
assert_eq!(buf.len(), 3, "Advanced an unexpected amount in src buf");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_should_advance_src_by_frame_size_when_successful() {
|
||||
let mut codec = PlainCodec::new();
|
||||
|
||||
// Add 3 extra bytes after a full frame
|
||||
let mut buf = BytesMut::new();
|
||||
codec
|
||||
.encode(b"hello, world", &mut buf)
|
||||
.expect("Failed to encode");
|
||||
buf.put_bytes(0, 3);
|
||||
|
||||
assert!(codec.decode(&mut buf).is_ok(), "Decode unexpectedly failed");
|
||||
assert_eq!(buf.len(), 3, "Advanced an unexpected amount in src buf");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_should_return_some_byte_vec_when_successful() {
|
||||
let mut codec = PlainCodec::new();
|
||||
|
||||
let mut buf = BytesMut::new();
|
||||
codec
|
||||
.encode(b"hello, world", &mut buf)
|
||||
.expect("Failed to encode");
|
||||
|
||||
let item = codec
|
||||
.decode(&mut buf)
|
||||
.expect("Failed to decode")
|
||||
.expect("Item not properly captured");
|
||||
assert_eq!(item, b"hello, world");
|
||||
}
|
||||
}
|
@ -1,269 +0,0 @@
|
||||
use crate::{Codec, SecretKey, SecretKey32};
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use chacha20poly1305::{aead::Aead, Key, KeyInit, XChaCha20Poly1305, XNonce};
|
||||
use std::{convert::TryInto, fmt};
|
||||
use tokio::io;
|
||||
|
||||
/// Total bytes to use as the len field denoting a frame's size
|
||||
const LEN_SIZE: usize = 8;
|
||||
|
||||
/// Total bytes to use for nonce
|
||||
const NONCE_SIZE: usize = 24;
|
||||
|
||||
/// Represents the codec to encode & decode data while also encrypting/decrypting it
|
||||
///
|
||||
/// Uses a 32-byte key internally
|
||||
#[derive(Clone)]
|
||||
pub struct XChaCha20Poly1305Codec {
|
||||
cipher: XChaCha20Poly1305,
|
||||
}
|
||||
impl_traits_for_codec!(XChaCha20Poly1305Codec);
|
||||
|
||||
impl XChaCha20Poly1305Codec {
|
||||
pub fn new(key: &[u8]) -> Self {
|
||||
let key = Key::from_slice(key);
|
||||
let cipher = XChaCha20Poly1305::new(key);
|
||||
Self { cipher }
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SecretKey32> for XChaCha20Poly1305Codec {
|
||||
/// Create a new XChaCha20Poly1305 codec with a 32-byte key
|
||||
fn from(secret_key: SecretKey32) -> Self {
|
||||
Self::new(secret_key.unprotected_as_bytes())
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for XChaCha20Poly1305Codec {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("XChaCha20Poly1305Codec")
|
||||
.field("cipher", &"**OMITTED**".to_string())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Codec for XChaCha20Poly1305Codec {
|
||||
fn encode(&mut self, item: &[u8], dst: &mut BytesMut) -> io::Result<()> {
|
||||
// Validate that we can fit the message plus nonce +
|
||||
if item.is_empty() {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"Empty item provided",
|
||||
));
|
||||
}
|
||||
// NOTE: As seen in orion, with a 24-bit nonce, it's safe to generate instead of
|
||||
// maintaining a stateful counter due to its size (24-byte secret key generation
|
||||
// will never panic)
|
||||
let nonce_key = SecretKey::<NONCE_SIZE>::generate().unwrap();
|
||||
let nonce = XNonce::from_slice(nonce_key.unprotected_as_bytes());
|
||||
|
||||
let ciphertext = self
|
||||
.cipher
|
||||
.encrypt(nonce, item)
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Encryption failed"))?;
|
||||
|
||||
dst.reserve(8 + nonce.len() + ciphertext.len());
|
||||
|
||||
// Add data in form of {LEN}{NONCE}{CIPHER TEXT}
|
||||
dst.put_u64((nonce_key.len() + ciphertext.len()) as u64);
|
||||
dst.put_slice(nonce.as_slice());
|
||||
dst.extend(ciphertext);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<Vec<u8>>> {
|
||||
// First, check if we have more data than just our frame's message length
|
||||
if src.len() <= LEN_SIZE {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Second, retrieve total size of our frame's message
|
||||
let msg_len = u64::from_be_bytes(src[..LEN_SIZE].try_into().unwrap()) as usize;
|
||||
if msg_len <= NONCE_SIZE {
|
||||
// Ensure we advance to remove the frame
|
||||
src.advance(LEN_SIZE + msg_len);
|
||||
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"Frame's msg cannot have length less than 25",
|
||||
));
|
||||
}
|
||||
|
||||
// Third, check if we have all data for our frame; if not, exit early
|
||||
if src.len() < msg_len + LEN_SIZE {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Fourth, retrieve the nonce used with the ciphertext
|
||||
let nonce = XNonce::from_slice(&src[LEN_SIZE..(NONCE_SIZE + LEN_SIZE)]);
|
||||
|
||||
// Fifth, acquire the encrypted & signed ciphertext
|
||||
let ciphertext = &src[(NONCE_SIZE + LEN_SIZE)..(msg_len + LEN_SIZE)];
|
||||
|
||||
// Sixth, convert ciphertext back into our item
|
||||
let item = self.cipher.decrypt(nonce, ciphertext);
|
||||
|
||||
// Seventh, advance so frame is no longer kept around
|
||||
src.advance(LEN_SIZE + msg_len);
|
||||
|
||||
// Eighth, report an error if there is one
|
||||
let item =
|
||||
item.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Decryption failed"))?;
|
||||
|
||||
Ok(Some(item))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn encode_should_fail_when_item_is_zero_bytes() {
|
||||
let key = SecretKey32::default();
|
||||
let mut codec = XChaCha20Poly1305Codec::from(key);
|
||||
|
||||
let mut buf = BytesMut::new();
|
||||
let result = codec.encode(&[], &mut buf);
|
||||
|
||||
match result {
|
||||
Err(x) if x.kind() == io::ErrorKind::InvalidInput => {}
|
||||
x => panic!("Unexpected result: {:?}", x),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encode_should_build_a_frame_containing_a_length_nonce_and_ciphertext() {
|
||||
let key = SecretKey32::default();
|
||||
let mut codec = XChaCha20Poly1305Codec::from(key);
|
||||
|
||||
let mut buf = BytesMut::new();
|
||||
codec
|
||||
.encode(b"hello, world", &mut buf)
|
||||
.expect("Failed to encode");
|
||||
|
||||
let len = buf.get_u64() as usize;
|
||||
assert!(buf.len() > NONCE_SIZE, "Msg size not big enough");
|
||||
assert_eq!(len, buf.len(), "Msg size does not match attached size");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_should_return_none_if_data_smaller_than_or_equal_to_frame_length_field() {
|
||||
let key = SecretKey32::default();
|
||||
let mut codec = XChaCha20Poly1305Codec::from(key);
|
||||
|
||||
let mut buf = BytesMut::new();
|
||||
buf.put_bytes(0, LEN_SIZE);
|
||||
|
||||
let result = codec.decode(&mut buf);
|
||||
assert!(
|
||||
matches!(result, Ok(None)),
|
||||
"Unexpected result: {:?}",
|
||||
result
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_should_return_none_if_not_enough_data_for_frame() {
|
||||
let key = SecretKey32::default();
|
||||
let mut codec = XChaCha20Poly1305Codec::from(key);
|
||||
|
||||
let mut buf = BytesMut::new();
|
||||
buf.put_u64(0);
|
||||
|
||||
let result = codec.decode(&mut buf);
|
||||
assert!(
|
||||
matches!(result, Ok(None)),
|
||||
"Unexpected result: {:?}",
|
||||
result
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_should_fail_if_encoded_frame_length_is_smaller_than_nonce_plus_data() {
|
||||
let key = SecretKey32::default();
|
||||
let mut codec = XChaCha20Poly1305Codec::from(key);
|
||||
|
||||
// NONCE_SIZE + 1 is minimum for frame length
|
||||
let mut buf = BytesMut::new();
|
||||
buf.put_u64(NONCE_SIZE as u64);
|
||||
buf.put_bytes(0, NONCE_SIZE);
|
||||
|
||||
let result = codec.decode(&mut buf);
|
||||
match result {
|
||||
Err(x) if x.kind() == io::ErrorKind::InvalidData => {}
|
||||
x => panic!("Unexpected result: {:?}", x),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_should_advance_src_by_frame_size_even_if_frame_length_is_too_small() {
|
||||
let key = SecretKey32::default();
|
||||
let mut codec = XChaCha20Poly1305Codec::from(key);
|
||||
|
||||
// LEN_SIZE + NONCE_SIZE + msg not matching encryption + 3 more bytes
|
||||
let mut buf = BytesMut::new();
|
||||
buf.put_u64(NONCE_SIZE as u64);
|
||||
buf.put_bytes(0, NONCE_SIZE);
|
||||
buf.put_bytes(0, 3);
|
||||
|
||||
assert!(
|
||||
codec.decode(&mut buf).is_err(),
|
||||
"Decode unexpectedly succeeded"
|
||||
);
|
||||
assert_eq!(buf.len(), 3, "Advanced an unexpected amount in src buf");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_should_advance_src_by_frame_size_even_if_decryption_fails() {
|
||||
let key = SecretKey32::default();
|
||||
let mut codec = XChaCha20Poly1305Codec::from(key);
|
||||
|
||||
// LEN_SIZE + NONCE_SIZE + msg not matching encryption + 3 more bytes
|
||||
let mut buf = BytesMut::new();
|
||||
buf.put_u64((NONCE_SIZE + 12) as u64);
|
||||
buf.put_bytes(0, NONCE_SIZE);
|
||||
buf.put_slice(b"hello, world");
|
||||
buf.put_bytes(0, 3);
|
||||
|
||||
assert!(
|
||||
codec.decode(&mut buf).is_err(),
|
||||
"Decode unexpectedly succeeded"
|
||||
);
|
||||
assert_eq!(buf.len(), 3, "Advanced an unexpected amount in src buf");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_should_advance_src_by_frame_size_when_successful() {
|
||||
let key = SecretKey32::default();
|
||||
let mut codec = XChaCha20Poly1305Codec::from(key);
|
||||
|
||||
// Add 3 extra bytes after a full frame
|
||||
let mut buf = BytesMut::new();
|
||||
codec
|
||||
.encode(b"hello, world", &mut buf)
|
||||
.expect("Failed to encode");
|
||||
buf.put_bytes(0, 3);
|
||||
|
||||
assert!(codec.decode(&mut buf).is_ok(), "Decode unexpectedly failed");
|
||||
assert_eq!(buf.len(), 3, "Advanced an unexpected amount in src buf");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_should_return_some_byte_vec_when_successful() {
|
||||
let key = SecretKey32::default();
|
||||
let mut codec = XChaCha20Poly1305Codec::from(key);
|
||||
|
||||
let mut buf = BytesMut::new();
|
||||
codec
|
||||
.encode(b"hello, world", &mut buf)
|
||||
.expect("Failed to encode");
|
||||
|
||||
let item = codec
|
||||
.decode(&mut buf)
|
||||
.expect("Failed to decode")
|
||||
.expect("Item not properly captured");
|
||||
assert_eq!(item, b"hello, world");
|
||||
}
|
||||
}
|
@ -0,0 +1,20 @@
|
||||
mod any;
|
||||
pub mod authentication;
|
||||
mod connection;
|
||||
mod destination;
|
||||
mod listener;
|
||||
mod map;
|
||||
mod packet;
|
||||
mod port;
|
||||
mod transport;
|
||||
pub(crate) mod utils;
|
||||
|
||||
pub use any::*;
|
||||
pub(crate) use connection::Connection;
|
||||
pub use connection::ConnectionId;
|
||||
pub use destination::*;
|
||||
pub use listener::*;
|
||||
pub use map::*;
|
||||
pub use packet::*;
|
||||
pub use port::*;
|
||||
pub use transport::*;
|
@ -0,0 +1,10 @@
|
||||
mod authenticator;
|
||||
mod handler;
|
||||
mod keychain;
|
||||
mod methods;
|
||||
pub mod msg;
|
||||
|
||||
pub use authenticator::*;
|
||||
pub use handler::*;
|
||||
pub use keychain::*;
|
||||
pub use methods::*;
|
@ -0,0 +1,672 @@
|
||||
use super::{msg::*, AuthHandler};
|
||||
use crate::common::{utils, FramedTransport, Transport};
|
||||
use async_trait::async_trait;
|
||||
use log::*;
|
||||
use std::io;
|
||||
|
||||
/// Represents an interface for authenticating with a server.
|
||||
#[async_trait]
|
||||
pub trait Authenticate {
|
||||
/// Performs authentication by leveraging the `handler` for any received challenge.
|
||||
async fn authenticate(&mut self, mut handler: impl AuthHandler + Send) -> io::Result<()>;
|
||||
}
|
||||
|
||||
/// Represents an interface for submitting challenges for authentication.
|
||||
#[async_trait]
|
||||
pub trait Authenticator: Send {
|
||||
/// Issues an initialization notice and returns the response indicating which authentication
|
||||
/// methods to pursue
|
||||
async fn initialize(
|
||||
&mut self,
|
||||
initialization: Initialization,
|
||||
) -> io::Result<InitializationResponse>;
|
||||
|
||||
/// Issues a challenge and returns the answers to the `questions` asked.
|
||||
async fn challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse>;
|
||||
|
||||
/// Requests verification of some `kind` and `text`, returning true if passed verification.
|
||||
async fn verify(&mut self, verification: Verification) -> io::Result<VerificationResponse>;
|
||||
|
||||
/// Reports information with no response expected.
|
||||
async fn info(&mut self, info: Info) -> io::Result<()>;
|
||||
|
||||
/// Reports an error occurred during authentication, consuming the authenticator since no more
|
||||
/// challenges should be issued.
|
||||
async fn error(&mut self, error: Error) -> io::Result<()>;
|
||||
|
||||
/// Reports that the authentication has started for a specific method.
|
||||
async fn start_method(&mut self, start_method: StartMethod) -> io::Result<()>;
|
||||
|
||||
/// Reports that the authentication has finished successfully, consuming the authenticator
|
||||
/// since no more challenges should be issued.
|
||||
async fn finished(&mut self) -> io::Result<()>;
|
||||
}
|
||||
|
||||
macro_rules! write_frame {
|
||||
($transport:expr, $data:expr) => {{
|
||||
let data = utils::serialize_to_vec(&$data)?;
|
||||
if log_enabled!(Level::Trace) {
|
||||
trace!("Writing data as frame: {data:?}");
|
||||
}
|
||||
|
||||
$transport.write_frame(data).await?
|
||||
}};
|
||||
}
|
||||
|
||||
macro_rules! next_frame_as {
|
||||
($transport:expr, $type:ident, $variant:ident) => {{
|
||||
match { next_frame_as!($transport, $type) } {
|
||||
$type::$variant(x) => x,
|
||||
x => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("Unexpected frame: {x:?}"),
|
||||
))
|
||||
}
|
||||
}
|
||||
}};
|
||||
($transport:expr, $type:ident) => {{
|
||||
let frame = $transport.read_frame().await?.ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::UnexpectedEof,
|
||||
concat!(
|
||||
"Transport closed early waiting for frame of type ",
|
||||
stringify!($type),
|
||||
),
|
||||
)
|
||||
})?;
|
||||
|
||||
match utils::deserialize_from_slice::<$type>(frame.as_item()) {
|
||||
Ok(frame) => frame,
|
||||
Err(x) => {
|
||||
if log_enabled!(Level::Trace) {
|
||||
trace!(
|
||||
"Failed to deserialize frame item as {}: {:?}",
|
||||
stringify!($type),
|
||||
frame.as_item()
|
||||
);
|
||||
}
|
||||
|
||||
Err(x)?;
|
||||
unreachable!();
|
||||
}
|
||||
}
|
||||
}};
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T> Authenticate for FramedTransport<T>
|
||||
where
|
||||
T: Transport,
|
||||
{
|
||||
async fn authenticate(&mut self, mut handler: impl AuthHandler + Send) -> io::Result<()> {
|
||||
loop {
|
||||
trace!("Authenticate::authenticate waiting on next authentication frame");
|
||||
match next_frame_as!(self, Authentication) {
|
||||
Authentication::Initialization(x) => {
|
||||
trace!("Authenticate::Initialization({x:?})");
|
||||
let response = handler.on_initialization(x).await?;
|
||||
write_frame!(self, AuthenticationResponse::Initialization(response));
|
||||
}
|
||||
Authentication::Challenge(x) => {
|
||||
trace!("Authenticate::Challenge({x:?})");
|
||||
let response = handler.on_challenge(x).await?;
|
||||
write_frame!(self, AuthenticationResponse::Challenge(response));
|
||||
}
|
||||
Authentication::Verification(x) => {
|
||||
trace!("Authenticate::Verify({x:?})");
|
||||
let response = handler.on_verification(x).await?;
|
||||
write_frame!(self, AuthenticationResponse::Verification(response));
|
||||
}
|
||||
Authentication::Info(x) => {
|
||||
trace!("Authenticate::Info({x:?})");
|
||||
handler.on_info(x).await?;
|
||||
}
|
||||
Authentication::Error(x) => {
|
||||
trace!("Authenticate::Error({x:?})");
|
||||
handler.on_error(x.clone()).await?;
|
||||
|
||||
if x.is_fatal() {
|
||||
return Err(x.into_io_permission_denied());
|
||||
}
|
||||
}
|
||||
Authentication::StartMethod(x) => {
|
||||
trace!("Authenticate::StartMethod({x:?})");
|
||||
handler.on_start_method(x).await?;
|
||||
}
|
||||
Authentication::Finished => {
|
||||
trace!("Authenticate::Finished");
|
||||
handler.on_finished().await?;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T> Authenticator for FramedTransport<T>
|
||||
where
|
||||
T: Transport,
|
||||
{
|
||||
async fn initialize(
|
||||
&mut self,
|
||||
initialization: Initialization,
|
||||
) -> io::Result<InitializationResponse> {
|
||||
trace!("Authenticator::initialize({initialization:?})");
|
||||
write_frame!(self, Authentication::Initialization(initialization));
|
||||
let response = next_frame_as!(self, AuthenticationResponse, Initialization);
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
async fn challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
|
||||
trace!("Authenticator::challenge({challenge:?})");
|
||||
write_frame!(self, Authentication::Challenge(challenge));
|
||||
let response = next_frame_as!(self, AuthenticationResponse, Challenge);
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
async fn verify(&mut self, verification: Verification) -> io::Result<VerificationResponse> {
|
||||
trace!("Authenticator::verify({verification:?})");
|
||||
write_frame!(self, Authentication::Verification(verification));
|
||||
let response = next_frame_as!(self, AuthenticationResponse, Verification);
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
async fn info(&mut self, info: Info) -> io::Result<()> {
|
||||
trace!("Authenticator::info({info:?})");
|
||||
write_frame!(self, Authentication::Info(info));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn error(&mut self, error: Error) -> io::Result<()> {
|
||||
trace!("Authenticator::error({error:?})");
|
||||
write_frame!(self, Authentication::Error(error));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
|
||||
trace!("Authenticator::start_method({start_method:?})");
|
||||
write_frame!(self, Authentication::StartMethod(start_method));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn finished(&mut self) -> io::Result<()> {
|
||||
trace!("Authenticator::finished()");
|
||||
write_frame!(self, Authentication::Finished);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::common::authentication::AuthMethodHandler;
|
||||
use test_log::test;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
#[async_trait]
|
||||
trait TestAuthHandler {
|
||||
async fn on_initialization(
|
||||
&mut self,
|
||||
_: Initialization,
|
||||
) -> io::Result<InitializationResponse> {
|
||||
Err(io::Error::from(io::ErrorKind::Unsupported))
|
||||
}
|
||||
|
||||
async fn on_start_method(&mut self, _: StartMethod) -> io::Result<()> {
|
||||
Err(io::Error::from(io::ErrorKind::Unsupported))
|
||||
}
|
||||
|
||||
async fn on_finished(&mut self) -> io::Result<()> {
|
||||
Err(io::Error::from(io::ErrorKind::Unsupported))
|
||||
}
|
||||
|
||||
async fn on_challenge(&mut self, _: Challenge) -> io::Result<ChallengeResponse> {
|
||||
Err(io::Error::from(io::ErrorKind::Unsupported))
|
||||
}
|
||||
|
||||
async fn on_verification(&mut self, _: Verification) -> io::Result<VerificationResponse> {
|
||||
Err(io::Error::from(io::ErrorKind::Unsupported))
|
||||
}
|
||||
|
||||
async fn on_info(&mut self, _: Info) -> io::Result<()> {
|
||||
Err(io::Error::from(io::ErrorKind::Unsupported))
|
||||
}
|
||||
|
||||
async fn on_error(&mut self, _: Error) -> io::Result<()> {
|
||||
Err(io::Error::from(io::ErrorKind::Unsupported))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T: TestAuthHandler + Send> AuthHandler for T {
|
||||
async fn on_initialization(
|
||||
&mut self,
|
||||
x: Initialization,
|
||||
) -> io::Result<InitializationResponse> {
|
||||
TestAuthHandler::on_initialization(self, x).await
|
||||
}
|
||||
|
||||
async fn on_start_method(&mut self, x: StartMethod) -> io::Result<()> {
|
||||
TestAuthHandler::on_start_method(self, x).await
|
||||
}
|
||||
|
||||
async fn on_finished(&mut self) -> io::Result<()> {
|
||||
TestAuthHandler::on_finished(self).await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T: TestAuthHandler + Send> AuthMethodHandler for T {
|
||||
async fn on_challenge(&mut self, x: Challenge) -> io::Result<ChallengeResponse> {
|
||||
TestAuthHandler::on_challenge(self, x).await
|
||||
}
|
||||
|
||||
async fn on_verification(&mut self, x: Verification) -> io::Result<VerificationResponse> {
|
||||
TestAuthHandler::on_verification(self, x).await
|
||||
}
|
||||
|
||||
async fn on_info(&mut self, x: Info) -> io::Result<()> {
|
||||
TestAuthHandler::on_info(self, x).await
|
||||
}
|
||||
|
||||
async fn on_error(&mut self, x: Error) -> io::Result<()> {
|
||||
TestAuthHandler::on_error(self, x).await
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! auth_handler {
|
||||
(@no_challenge @no_verification @tx($tx:ident, $ty:ty) $($methods:item)*) => {
|
||||
auth_handler! {
|
||||
@tx($tx, $ty)
|
||||
|
||||
async fn on_challenge(&mut self, _: Challenge) -> io::Result<ChallengeResponse> {
|
||||
Err(io::Error::from(io::ErrorKind::Unsupported))
|
||||
}
|
||||
|
||||
async fn on_verification(
|
||||
&mut self,
|
||||
_: Verification,
|
||||
) -> io::Result<VerificationResponse> {
|
||||
Err(io::Error::from(io::ErrorKind::Unsupported))
|
||||
}
|
||||
|
||||
$($methods)*
|
||||
}
|
||||
};
|
||||
(@no_challenge @tx($tx:ident, $ty:ty) $($methods:item)*) => {
|
||||
auth_handler! {
|
||||
@tx($tx, $ty)
|
||||
|
||||
async fn on_challenge(&mut self, _: Challenge) -> io::Result<ChallengeResponse> {
|
||||
Err(io::Error::from(io::ErrorKind::Unsupported))
|
||||
}
|
||||
|
||||
$($methods)*
|
||||
}
|
||||
};
|
||||
(@no_verification @tx($tx:ident, $ty:ty) $($methods:item)*) => {
|
||||
auth_handler! {
|
||||
@tx($tx, $ty)
|
||||
|
||||
async fn on_verification(
|
||||
&mut self,
|
||||
_: Verification,
|
||||
) -> io::Result<VerificationResponse> {
|
||||
Err(io::Error::from(io::ErrorKind::Unsupported))
|
||||
}
|
||||
|
||||
$($methods)*
|
||||
}
|
||||
};
|
||||
(@tx($tx:ident, $ty:ty) $($methods:item)*) => {{
|
||||
#[allow(dead_code)]
|
||||
struct __InlineAuthHandler {
|
||||
tx: mpsc::Sender<$ty>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TestAuthHandler for __InlineAuthHandler {
|
||||
$($methods)*
|
||||
}
|
||||
|
||||
__InlineAuthHandler { tx: $tx }
|
||||
}};
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn authenticator_initialization_should_be_able_to_successfully_complete_round_trip() {
|
||||
let (mut t1, mut t2) = FramedTransport::test_pair(100);
|
||||
let (tx, _) = mpsc::channel(1);
|
||||
|
||||
let task = tokio::spawn(async move {
|
||||
t2.authenticate(auth_handler! {
|
||||
@no_challenge
|
||||
@no_verification
|
||||
@tx(tx, ())
|
||||
|
||||
async fn on_initialization(
|
||||
&mut self,
|
||||
initialization: Initialization,
|
||||
) -> io::Result<InitializationResponse> {
|
||||
Ok(InitializationResponse {
|
||||
methods: initialization.methods,
|
||||
})
|
||||
}
|
||||
})
|
||||
.await
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
let response = t1
|
||||
.initialize(Initialization {
|
||||
methods: vec!["test method".to_string()].into_iter().collect(),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
!task.is_finished(),
|
||||
"Auth handler unexpectedly finished without signal"
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
response,
|
||||
InitializationResponse {
|
||||
methods: vec!["test method".to_string()].into_iter().collect()
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn authenticator_challenge_should_be_able_to_successfully_complete_round_trip() {
|
||||
let (mut t1, mut t2) = FramedTransport::test_pair(100);
|
||||
let (tx, _) = mpsc::channel(1);
|
||||
|
||||
let task = tokio::spawn(async move {
|
||||
t2.authenticate(auth_handler! {
|
||||
@no_verification
|
||||
@tx(tx, ())
|
||||
|
||||
async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
|
||||
assert_eq!(challenge.questions, vec![Question {
|
||||
label: "label".to_string(),
|
||||
text: "text".to_string(),
|
||||
options: vec![("question_key".to_string(), "question_value".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
}]);
|
||||
assert_eq!(
|
||||
challenge.options,
|
||||
vec![("key".to_string(), "value".to_string())].into_iter().collect(),
|
||||
);
|
||||
Ok(ChallengeResponse {
|
||||
answers: vec!["some answer".to_string()].into_iter().collect(),
|
||||
})
|
||||
}
|
||||
})
|
||||
.await
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
let response = t1
|
||||
.challenge(Challenge {
|
||||
questions: vec![Question {
|
||||
label: "label".to_string(),
|
||||
text: "text".to_string(),
|
||||
options: vec![("question_key".to_string(), "question_value".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
}],
|
||||
options: vec![("key".to_string(), "value".to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
!task.is_finished(),
|
||||
"Auth handler unexpectedly finished without signal"
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
response,
|
||||
ChallengeResponse {
|
||||
answers: vec!["some answer".to_string()],
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn authenticator_verification_should_be_able_to_successfully_complete_round_trip() {
|
||||
let (mut t1, mut t2) = FramedTransport::test_pair(100);
|
||||
let (tx, _) = mpsc::channel(1);
|
||||
|
||||
let task = tokio::spawn(async move {
|
||||
t2.authenticate(auth_handler! {
|
||||
@no_challenge
|
||||
@tx(tx, ())
|
||||
|
||||
async fn on_verification(
|
||||
&mut self,
|
||||
verification: Verification,
|
||||
) -> io::Result<VerificationResponse> {
|
||||
assert_eq!(verification.kind, VerificationKind::Host);
|
||||
assert_eq!(verification.text, "some text");
|
||||
Ok(VerificationResponse {
|
||||
valid: true,
|
||||
})
|
||||
}
|
||||
})
|
||||
.await
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
let response = t1
|
||||
.verify(Verification {
|
||||
kind: VerificationKind::Host,
|
||||
text: "some text".to_string(),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
!task.is_finished(),
|
||||
"Auth handler unexpectedly finished without signal"
|
||||
);
|
||||
|
||||
assert_eq!(response, VerificationResponse { valid: true });
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn authenticator_info_should_be_able_to_be_sent_to_auth_handler() {
|
||||
let (mut t1, mut t2) = FramedTransport::test_pair(100);
|
||||
let (tx, mut rx) = mpsc::channel(1);
|
||||
|
||||
let task = tokio::spawn(async move {
|
||||
t2.authenticate(auth_handler! {
|
||||
@no_challenge
|
||||
@no_verification
|
||||
@tx(tx, Info)
|
||||
|
||||
async fn on_info(
|
||||
&mut self,
|
||||
info: Info,
|
||||
) -> io::Result<()> {
|
||||
self.tx.send(info).await.unwrap();
|
||||
Ok(())
|
||||
}
|
||||
})
|
||||
.await
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
t1.info(Info {
|
||||
text: "some text".to_string(),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
rx.recv().await.unwrap(),
|
||||
Info {
|
||||
text: "some text".to_string()
|
||||
}
|
||||
);
|
||||
|
||||
assert!(
|
||||
!task.is_finished(),
|
||||
"Auth handler unexpectedly finished without signal"
|
||||
);
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn authenticator_error_should_be_able_to_be_sent_to_auth_handler() {
|
||||
let (mut t1, mut t2) = FramedTransport::test_pair(100);
|
||||
let (tx, mut rx) = mpsc::channel(1);
|
||||
|
||||
let task = tokio::spawn(async move {
|
||||
t2.authenticate(auth_handler! {
|
||||
@no_challenge
|
||||
@no_verification
|
||||
@tx(tx, Error)
|
||||
|
||||
async fn on_error(&mut self, error: Error) -> io::Result<()> {
|
||||
self.tx.send(error).await.unwrap();
|
||||
Ok(())
|
||||
}
|
||||
})
|
||||
.await
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
t1.error(Error {
|
||||
kind: ErrorKind::Error,
|
||||
text: "some text".to_string(),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
rx.recv().await.unwrap(),
|
||||
Error {
|
||||
kind: ErrorKind::Error,
|
||||
text: "some text".to_string(),
|
||||
}
|
||||
);
|
||||
|
||||
assert!(
|
||||
!task.is_finished(),
|
||||
"Auth handler unexpectedly finished without signal"
|
||||
);
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn auth_handler_received_error_should_fail_auth_handler_if_fatal() {
|
||||
let (mut t1, mut t2) = FramedTransport::test_pair(100);
|
||||
let (tx, mut rx) = mpsc::channel(1);
|
||||
|
||||
let task = tokio::spawn(async move {
|
||||
t2.authenticate(auth_handler! {
|
||||
@no_challenge
|
||||
@no_verification
|
||||
@tx(tx, Error)
|
||||
|
||||
async fn on_error(&mut self, error: Error) -> io::Result<()> {
|
||||
self.tx.send(error).await.unwrap();
|
||||
Ok(())
|
||||
}
|
||||
})
|
||||
.await
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
t1.error(Error {
|
||||
kind: ErrorKind::Fatal,
|
||||
text: "some text".to_string(),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
rx.recv().await.unwrap(),
|
||||
Error {
|
||||
kind: ErrorKind::Fatal,
|
||||
text: "some text".to_string(),
|
||||
}
|
||||
);
|
||||
|
||||
// Verify that the handler exited with an error
|
||||
task.await.unwrap_err();
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn authenticator_start_method_should_be_able_to_be_sent_to_auth_handler() {
|
||||
let (mut t1, mut t2) = FramedTransport::test_pair(100);
|
||||
let (tx, mut rx) = mpsc::channel(1);
|
||||
|
||||
let task = tokio::spawn(async move {
|
||||
t2.authenticate(auth_handler! {
|
||||
@no_challenge
|
||||
@no_verification
|
||||
@tx(tx, StartMethod)
|
||||
|
||||
async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
|
||||
self.tx.send(start_method).await.unwrap();
|
||||
Ok(())
|
||||
}
|
||||
})
|
||||
.await
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
t1.start_method(StartMethod {
|
||||
method: "some method".to_string(),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
rx.recv().await.unwrap(),
|
||||
StartMethod {
|
||||
method: "some method".to_string()
|
||||
}
|
||||
);
|
||||
|
||||
assert!(
|
||||
!task.is_finished(),
|
||||
"Auth handler unexpectedly finished without signal"
|
||||
);
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn authenticator_finished_should_be_able_to_be_sent_to_auth_handler() {
|
||||
let (mut t1, mut t2) = FramedTransport::test_pair(100);
|
||||
let (tx, mut rx) = mpsc::channel(1);
|
||||
|
||||
let task = tokio::spawn(async move {
|
||||
t2.authenticate(auth_handler! {
|
||||
@no_challenge
|
||||
@no_verification
|
||||
@tx(tx, ())
|
||||
|
||||
async fn on_finished(&mut self) -> io::Result<()> {
|
||||
self.tx.send(()).await.unwrap();
|
||||
Ok(())
|
||||
}
|
||||
})
|
||||
.await
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
t1.finished().await.unwrap();
|
||||
|
||||
// Verify that the callback was triggered
|
||||
rx.recv().await.unwrap();
|
||||
|
||||
// Finished should signal that the handler completed successfully
|
||||
task.await.unwrap();
|
||||
}
|
||||
}
|
@ -0,0 +1,343 @@
|
||||
use super::msg::*;
|
||||
use crate::common::authentication::Authenticator;
|
||||
use crate::common::HeapSecretKey;
|
||||
use async_trait::async_trait;
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
|
||||
mod methods;
|
||||
pub use methods::*;
|
||||
|
||||
/// Interface for a handler of authentication requests for all methods.
|
||||
#[async_trait]
|
||||
pub trait AuthHandler: AuthMethodHandler + Send {
|
||||
/// Callback when authentication is beginning, providing available authentication methods and
|
||||
/// returning selected authentication methods to pursue.
|
||||
async fn on_initialization(
|
||||
&mut self,
|
||||
initialization: Initialization,
|
||||
) -> io::Result<InitializationResponse> {
|
||||
Ok(InitializationResponse {
|
||||
methods: initialization.methods,
|
||||
})
|
||||
}
|
||||
|
||||
/// Callback when authentication starts for a specific method.
|
||||
#[allow(unused_variables)]
|
||||
async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Callback when authentication is finished and no more requests will be received.
|
||||
async fn on_finished(&mut self) -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Dummy implementation of [`AuthHandler`] where any challenge or verification request will
|
||||
/// instantly fail.
|
||||
pub struct DummyAuthHandler;
|
||||
|
||||
#[async_trait]
|
||||
impl AuthHandler for DummyAuthHandler {}
|
||||
|
||||
#[async_trait]
|
||||
impl AuthMethodHandler for DummyAuthHandler {
|
||||
async fn on_challenge(&mut self, _: Challenge) -> io::Result<ChallengeResponse> {
|
||||
Err(io::Error::from(io::ErrorKind::Unsupported))
|
||||
}
|
||||
|
||||
async fn on_verification(&mut self, _: Verification) -> io::Result<VerificationResponse> {
|
||||
Err(io::Error::from(io::ErrorKind::Unsupported))
|
||||
}
|
||||
|
||||
async fn on_info(&mut self, _: Info) -> io::Result<()> {
|
||||
Err(io::Error::from(io::ErrorKind::Unsupported))
|
||||
}
|
||||
|
||||
async fn on_error(&mut self, _: Error) -> io::Result<()> {
|
||||
Err(io::Error::from(io::ErrorKind::Unsupported))
|
||||
}
|
||||
}
|
||||
|
||||
/// Implementation of [`AuthHandler`] that uses the same [`AuthMethodHandler`] for all methods.
|
||||
pub struct SingleAuthHandler(Box<dyn AuthMethodHandler>);
|
||||
|
||||
impl SingleAuthHandler {
|
||||
pub fn new<T: AuthMethodHandler + 'static>(method_handler: T) -> Self {
|
||||
Self(Box::new(method_handler))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AuthHandler for SingleAuthHandler {}
|
||||
|
||||
#[async_trait]
|
||||
impl AuthMethodHandler for SingleAuthHandler {
|
||||
async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
|
||||
self.0.on_challenge(challenge).await
|
||||
}
|
||||
|
||||
async fn on_verification(
|
||||
&mut self,
|
||||
verification: Verification,
|
||||
) -> io::Result<VerificationResponse> {
|
||||
self.0.on_verification(verification).await
|
||||
}
|
||||
|
||||
async fn on_info(&mut self, info: Info) -> io::Result<()> {
|
||||
self.0.on_info(info).await
|
||||
}
|
||||
|
||||
async fn on_error(&mut self, error: Error) -> io::Result<()> {
|
||||
self.0.on_error(error).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Implementation of [`AuthHandler`] that maintains a map of [`AuthMethodHandler`] implementations
|
||||
/// for specific methods, invoking [`on_challenge`], [`on_verification`], [`on_info`], and
|
||||
/// [`on_error`] for a specific handler based on an associated id.
|
||||
///
|
||||
/// [`on_challenge`]: AuthMethodHandler::on_challenge
|
||||
/// [`on_verification`]: AuthMethodHandler::on_verification
|
||||
/// [`on_info`]: AuthMethodHandler::on_info
|
||||
/// [`on_error`]: AuthMethodHandler::on_error
|
||||
pub struct AuthHandlerMap {
|
||||
active: String,
|
||||
map: HashMap<&'static str, Box<dyn AuthMethodHandler>>,
|
||||
}
|
||||
|
||||
impl AuthHandlerMap {
|
||||
/// Creates a new, empty map of auth method handlers.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
active: String::new(),
|
||||
map: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the `id` of the active [`AuthMethodHandler`].
|
||||
pub fn active_id(&self) -> &str {
|
||||
&self.active
|
||||
}
|
||||
|
||||
/// Sets the active [`AuthMethodHandler`] by its `id`.
|
||||
pub fn set_active_id(&mut self, id: impl Into<String>) {
|
||||
self.active = id.into();
|
||||
}
|
||||
|
||||
/// Inserts the specified `handler` into the map, associating it with `id` for determining the
|
||||
/// method that would trigger this handler.
|
||||
pub fn insert_method_handler<T: AuthMethodHandler + 'static>(
|
||||
&mut self,
|
||||
id: &'static str,
|
||||
handler: T,
|
||||
) -> Option<Box<dyn AuthMethodHandler>> {
|
||||
self.map.insert(id, Box::new(handler))
|
||||
}
|
||||
|
||||
/// Removes a handler with the associated `id`.
|
||||
pub fn remove_method_handler(
|
||||
&mut self,
|
||||
id: &'static str,
|
||||
) -> Option<Box<dyn AuthMethodHandler>> {
|
||||
self.map.remove(id)
|
||||
}
|
||||
|
||||
/// Retrieves a mutable reference to the active [`AuthMethodHandler`] with the specified `id`,
|
||||
/// returning an error if no handler for the active id is found.
|
||||
pub fn get_mut_active_method_handler_or_error(
|
||||
&mut self,
|
||||
) -> io::Result<&mut (dyn AuthMethodHandler + 'static)> {
|
||||
let id = self.active.clone();
|
||||
self.get_mut_active_method_handler().ok_or_else(|| {
|
||||
io::Error::new(io::ErrorKind::Other, format!("No active handler for {id}"))
|
||||
})
|
||||
}
|
||||
|
||||
/// Retrieves a mutable reference to the active [`AuthMethodHandler`] with the specified `id`.
|
||||
pub fn get_mut_active_method_handler(
|
||||
&mut self,
|
||||
) -> Option<&mut (dyn AuthMethodHandler + 'static)> {
|
||||
// TODO: Optimize this
|
||||
self.get_mut_method_handler(&self.active.clone())
|
||||
}
|
||||
|
||||
/// Retrieves a mutable reference to the [`AuthMethodHandler`] with the specified `id`.
|
||||
pub fn get_mut_method_handler(
|
||||
&mut self,
|
||||
id: &str,
|
||||
) -> Option<&mut (dyn AuthMethodHandler + 'static)> {
|
||||
self.map.get_mut(id).map(|h| h.as_mut())
|
||||
}
|
||||
}
|
||||
|
||||
impl AuthHandlerMap {
|
||||
/// Consumes the map, returning a new map that supports the `static_key` method.
|
||||
pub fn with_static_key(mut self, key: impl Into<HeapSecretKey>) -> Self {
|
||||
self.insert_method_handler("static_key", StaticKeyAuthMethodHandler::simple(key));
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AuthHandlerMap {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AuthHandler for AuthHandlerMap {
|
||||
async fn on_initialization(
|
||||
&mut self,
|
||||
initialization: Initialization,
|
||||
) -> io::Result<InitializationResponse> {
|
||||
let methods = initialization
|
||||
.methods
|
||||
.into_iter()
|
||||
.filter(|method| self.map.contains_key(method.as_str()))
|
||||
.collect();
|
||||
|
||||
Ok(InitializationResponse { methods })
|
||||
}
|
||||
|
||||
async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
|
||||
self.set_active_id(start_method.method);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn on_finished(&mut self) -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AuthMethodHandler for AuthHandlerMap {
|
||||
async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
|
||||
let handler = self.get_mut_active_method_handler_or_error()?;
|
||||
handler.on_challenge(challenge).await
|
||||
}
|
||||
|
||||
async fn on_verification(
|
||||
&mut self,
|
||||
verification: Verification,
|
||||
) -> io::Result<VerificationResponse> {
|
||||
let handler = self.get_mut_active_method_handler_or_error()?;
|
||||
handler.on_verification(verification).await
|
||||
}
|
||||
|
||||
async fn on_info(&mut self, info: Info) -> io::Result<()> {
|
||||
let handler = self.get_mut_active_method_handler_or_error()?;
|
||||
handler.on_info(info).await
|
||||
}
|
||||
|
||||
async fn on_error(&mut self, error: Error) -> io::Result<()> {
|
||||
let handler = self.get_mut_active_method_handler_or_error()?;
|
||||
handler.on_error(error).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Implementation of [`AuthHandler`] that redirects all requests to an [`Authenticator`].
|
||||
pub struct ProxyAuthHandler<'a>(&'a mut dyn Authenticator);
|
||||
|
||||
impl<'a> ProxyAuthHandler<'a> {
|
||||
pub fn new(authenticator: &'a mut dyn Authenticator) -> Self {
|
||||
Self(authenticator)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'a> AuthHandler for ProxyAuthHandler<'a> {
|
||||
async fn on_initialization(
|
||||
&mut self,
|
||||
initialization: Initialization,
|
||||
) -> io::Result<InitializationResponse> {
|
||||
Authenticator::initialize(self.0, initialization).await
|
||||
}
|
||||
|
||||
async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
|
||||
Authenticator::start_method(self.0, start_method).await
|
||||
}
|
||||
|
||||
async fn on_finished(&mut self) -> io::Result<()> {
|
||||
Authenticator::finished(self.0).await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'a> AuthMethodHandler for ProxyAuthHandler<'a> {
|
||||
async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
|
||||
Authenticator::challenge(self.0, challenge).await
|
||||
}
|
||||
|
||||
async fn on_verification(
|
||||
&mut self,
|
||||
verification: Verification,
|
||||
) -> io::Result<VerificationResponse> {
|
||||
Authenticator::verify(self.0, verification).await
|
||||
}
|
||||
|
||||
async fn on_info(&mut self, info: Info) -> io::Result<()> {
|
||||
Authenticator::info(self.0, info).await
|
||||
}
|
||||
|
||||
async fn on_error(&mut self, error: Error) -> io::Result<()> {
|
||||
Authenticator::error(self.0, error).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Implementation of [`AuthHandler`] that holds a mutable reference to another [`AuthHandler`]
|
||||
/// trait object to use underneath.
|
||||
pub struct DynAuthHandler<'a>(&'a mut dyn AuthHandler);
|
||||
|
||||
impl<'a> DynAuthHandler<'a> {
|
||||
pub fn new(handler: &'a mut dyn AuthHandler) -> Self {
|
||||
Self(handler)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: AuthHandler> From<&'a mut T> for DynAuthHandler<'a> {
|
||||
fn from(handler: &'a mut T) -> Self {
|
||||
Self::new(handler as &mut dyn AuthHandler)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'a> AuthHandler for DynAuthHandler<'a> {
|
||||
async fn on_initialization(
|
||||
&mut self,
|
||||
initialization: Initialization,
|
||||
) -> io::Result<InitializationResponse> {
|
||||
self.0.on_initialization(initialization).await
|
||||
}
|
||||
|
||||
async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
|
||||
self.0.on_start_method(start_method).await
|
||||
}
|
||||
|
||||
async fn on_finished(&mut self) -> io::Result<()> {
|
||||
self.0.on_finished().await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'a> AuthMethodHandler for DynAuthHandler<'a> {
|
||||
async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
|
||||
self.0.on_challenge(challenge).await
|
||||
}
|
||||
|
||||
async fn on_verification(
|
||||
&mut self,
|
||||
verification: Verification,
|
||||
) -> io::Result<VerificationResponse> {
|
||||
self.0.on_verification(verification).await
|
||||
}
|
||||
|
||||
async fn on_info(&mut self, info: Info) -> io::Result<()> {
|
||||
self.0.on_info(info).await
|
||||
}
|
||||
|
||||
async fn on_error(&mut self, error: Error) -> io::Result<()> {
|
||||
self.0.on_error(error).await
|
||||
}
|
||||
}
|
@ -0,0 +1,33 @@
|
||||
use super::{
|
||||
Challenge, ChallengeResponse, Error, Info, Verification, VerificationKind, VerificationResponse,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use std::io;
|
||||
|
||||
/// Interface for a handler of authentication requests for a specific authentication method.
|
||||
#[async_trait]
|
||||
pub trait AuthMethodHandler: Send {
|
||||
/// Callback when a challenge is received, returning answers to the given questions.
|
||||
async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse>;
|
||||
|
||||
/// Callback when a verification request is received, returning true if approvided or false if
|
||||
/// unapproved.
|
||||
async fn on_verification(
|
||||
&mut self,
|
||||
verification: Verification,
|
||||
) -> io::Result<VerificationResponse>;
|
||||
|
||||
/// Callback when information is received. To fail, return an error from this function.
|
||||
async fn on_info(&mut self, info: Info) -> io::Result<()>;
|
||||
|
||||
/// Callback when an error is received. Regardless of the result returned, this will terminate
|
||||
/// the authenticator. In the situation where a custom error would be preferred, have this
|
||||
/// callback return an error.
|
||||
async fn on_error(&mut self, error: Error) -> io::Result<()>;
|
||||
}
|
||||
|
||||
mod prompt;
|
||||
pub use prompt::*;
|
||||
|
||||
mod static_key;
|
||||
pub use static_key::*;
|
@ -0,0 +1,88 @@
|
||||
use super::{
|
||||
AuthMethodHandler, Challenge, ChallengeResponse, Error, Info, Verification, VerificationKind,
|
||||
VerificationResponse,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use log::*;
|
||||
use std::io;
|
||||
|
||||
/// Blocking implementation of [`AuthMethodHandler`] that uses prompts to communicate challenge &
|
||||
/// verification requests, receiving responses to relay back.
|
||||
pub struct PromptAuthMethodHandler<T, U> {
|
||||
text_prompt: T,
|
||||
password_prompt: U,
|
||||
}
|
||||
|
||||
impl<T, U> PromptAuthMethodHandler<T, U> {
|
||||
pub fn new(text_prompt: T, password_prompt: U) -> Self {
|
||||
Self {
|
||||
text_prompt,
|
||||
password_prompt,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T, U> AuthMethodHandler for PromptAuthMethodHandler<T, U>
|
||||
where
|
||||
T: Fn(&str) -> io::Result<String> + Send + Sync + 'static,
|
||||
U: Fn(&str) -> io::Result<String> + Send + Sync + 'static,
|
||||
{
|
||||
async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
|
||||
trace!("on_challenge({challenge:?})");
|
||||
let mut answers = Vec::new();
|
||||
for question in challenge.questions.iter() {
|
||||
// Contains all prompt lines including same line
|
||||
let mut lines = question.text.split('\n').collect::<Vec<_>>();
|
||||
|
||||
// Line that is prompt on same line as answer
|
||||
let line = lines.pop().unwrap();
|
||||
|
||||
// Go ahead and display all other lines
|
||||
for line in lines.into_iter() {
|
||||
eprintln!("{}", line);
|
||||
}
|
||||
|
||||
// Get an answer from user input, or use a blank string as an answer
|
||||
// if we fail to get input from the user
|
||||
let answer = (self.password_prompt)(line).unwrap_or_default();
|
||||
|
||||
answers.push(answer);
|
||||
}
|
||||
Ok(ChallengeResponse { answers })
|
||||
}
|
||||
|
||||
async fn on_verification(
|
||||
&mut self,
|
||||
verification: Verification,
|
||||
) -> io::Result<VerificationResponse> {
|
||||
trace!("on_verify({verification:?})");
|
||||
match verification.kind {
|
||||
VerificationKind::Host => {
|
||||
eprintln!("{}", verification.text);
|
||||
|
||||
let answer = (self.text_prompt)("Enter [y/N]> ")?;
|
||||
trace!("Verify? Answer = '{answer}'");
|
||||
Ok(VerificationResponse {
|
||||
valid: matches!(answer.trim(), "y" | "Y" | "yes" | "YES"),
|
||||
})
|
||||
}
|
||||
x => {
|
||||
error!("Unsupported verify kind: {x}");
|
||||
Ok(VerificationResponse { valid: false })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn on_info(&mut self, info: Info) -> io::Result<()> {
|
||||
trace!("on_info({info:?})");
|
||||
println!("{}", info.text);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn on_error(&mut self, error: Error) -> io::Result<()> {
|
||||
trace!("on_error({error:?})");
|
||||
eprintln!("{}: {}", error.kind, error.text);
|
||||
Ok(())
|
||||
}
|
||||
}
|
@ -0,0 +1,171 @@
|
||||
use super::{
|
||||
AuthMethodHandler, Challenge, ChallengeResponse, Error, Info, Verification,
|
||||
VerificationResponse,
|
||||
};
|
||||
use crate::common::HeapSecretKey;
|
||||
use async_trait::async_trait;
|
||||
use log::*;
|
||||
use std::io;
|
||||
|
||||
/// Implementation of [`AuthMethodHandler`] that answers challenge requests using a static
|
||||
/// [`HeapSecretKey`]. All other portions of method authentication are handled by another
|
||||
/// [`AuthMethodHandler`].
|
||||
pub struct StaticKeyAuthMethodHandler {
|
||||
key: HeapSecretKey,
|
||||
handler: Box<dyn AuthMethodHandler>,
|
||||
}
|
||||
|
||||
impl StaticKeyAuthMethodHandler {
|
||||
/// Creates a new [`StaticKeyAuthMethodHandler`] that responds to challenges using a static
|
||||
/// `key`. All other requests are passed to the `handler`.
|
||||
pub fn new<T: AuthMethodHandler + 'static>(key: impl Into<HeapSecretKey>, handler: T) -> Self {
|
||||
Self {
|
||||
key: key.into(),
|
||||
handler: Box::new(handler),
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new [`StaticKeyAuthMethodHandler`] that responds to challenges using a static
|
||||
/// `key`. All other requests are passed automatically, meaning that verification is always
|
||||
/// approvide and info/errors are ignored.
|
||||
pub fn simple(key: impl Into<HeapSecretKey>) -> Self {
|
||||
Self::new(key, {
|
||||
struct __AuthMethodHandler;
|
||||
|
||||
#[async_trait]
|
||||
impl AuthMethodHandler for __AuthMethodHandler {
|
||||
async fn on_challenge(&mut self, _: Challenge) -> io::Result<ChallengeResponse> {
|
||||
unreachable!("on_challenge should be handled by StaticKeyAuthMethodHandler");
|
||||
}
|
||||
|
||||
async fn on_verification(
|
||||
&mut self,
|
||||
_: Verification,
|
||||
) -> io::Result<VerificationResponse> {
|
||||
Ok(VerificationResponse { valid: true })
|
||||
}
|
||||
|
||||
async fn on_info(&mut self, _: Info) -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn on_error(&mut self, _: Error) -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
__AuthMethodHandler
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AuthMethodHandler for StaticKeyAuthMethodHandler {
|
||||
async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
|
||||
trace!("on_challenge({challenge:?})");
|
||||
let mut answers = Vec::new();
|
||||
for question in challenge.questions.iter() {
|
||||
// Only challenges with a "key" label are allowed, all else will fail
|
||||
if question.label != "key" {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"Only 'key' challenges are supported",
|
||||
));
|
||||
}
|
||||
answers.push(self.key.to_string());
|
||||
}
|
||||
Ok(ChallengeResponse { answers })
|
||||
}
|
||||
|
||||
async fn on_verification(
|
||||
&mut self,
|
||||
verification: Verification,
|
||||
) -> io::Result<VerificationResponse> {
|
||||
trace!("on_verify({verification:?})");
|
||||
self.handler.on_verification(verification).await
|
||||
}
|
||||
|
||||
async fn on_info(&mut self, info: Info) -> io::Result<()> {
|
||||
trace!("on_info({info:?})");
|
||||
self.handler.on_info(info).await
|
||||
}
|
||||
|
||||
async fn on_error(&mut self, error: Error) -> io::Result<()> {
|
||||
trace!("on_error({error:?})");
|
||||
self.handler.on_error(error).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::common::authentication::msg::{ErrorKind, Question, VerificationKind};
|
||||
use test_log::test;
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn on_challenge_should_fail_if_non_key_question_received() {
|
||||
let mut handler = StaticKeyAuthMethodHandler::simple(HeapSecretKey::generate(32).unwrap());
|
||||
|
||||
handler
|
||||
.on_challenge(Challenge {
|
||||
questions: vec![Question::new("test")],
|
||||
options: Default::default(),
|
||||
})
|
||||
.await
|
||||
.unwrap_err();
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn on_challenge_should_answer_with_stringified_key_for_key_questions() {
|
||||
let mut handler = StaticKeyAuthMethodHandler::simple(HeapSecretKey::generate(32).unwrap());
|
||||
|
||||
let response = handler
|
||||
.on_challenge(Challenge {
|
||||
questions: vec![Question::new("key")],
|
||||
options: Default::default(),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(response.answers.len(), 1, "Wrong answer set received");
|
||||
assert!(!response.answers[0].is_empty(), "Empty answer being sent");
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn on_verification_should_leverage_fallback_handler() {
|
||||
let mut handler = StaticKeyAuthMethodHandler::simple(HeapSecretKey::generate(32).unwrap());
|
||||
|
||||
let response = handler
|
||||
.on_verification(Verification {
|
||||
kind: VerificationKind::Host,
|
||||
text: "host".to_string(),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(response.valid, "Unexpected result from fallback handler");
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn on_info_should_leverage_fallback_handler() {
|
||||
let mut handler = StaticKeyAuthMethodHandler::simple(HeapSecretKey::generate(32).unwrap());
|
||||
|
||||
handler
|
||||
.on_info(Info {
|
||||
text: "info".to_string(),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn on_error_should_leverage_fallback_handler() {
|
||||
let mut handler = StaticKeyAuthMethodHandler::simple(HeapSecretKey::generate(32).unwrap());
|
||||
|
||||
handler
|
||||
.on_error(Error {
|
||||
kind: ErrorKind::Error,
|
||||
text: "text".to_string(),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
@ -0,0 +1,156 @@
|
||||
use crate::common::HeapSecretKey;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// Represents the result of a request to the database.
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
|
||||
pub enum KeychainResult<T> {
|
||||
/// Id was not found in the database.
|
||||
InvalidId,
|
||||
|
||||
/// Password match for an id failed.
|
||||
InvalidPassword,
|
||||
|
||||
/// Successful match of id and password, removing from keychain and returning data `T`.
|
||||
Ok(T),
|
||||
}
|
||||
|
||||
impl<T> KeychainResult<T> {
|
||||
pub fn is_invalid_id(&self) -> bool {
|
||||
matches!(self, Self::InvalidId)
|
||||
}
|
||||
|
||||
pub fn is_invalid_password(&self) -> bool {
|
||||
matches!(self, Self::InvalidPassword)
|
||||
}
|
||||
|
||||
pub fn is_invalid(&self) -> bool {
|
||||
matches!(self, Self::InvalidId | Self::InvalidPassword)
|
||||
}
|
||||
|
||||
pub fn is_ok(&self) -> bool {
|
||||
matches!(self, Self::Ok(_))
|
||||
}
|
||||
|
||||
pub fn into_ok(self) -> Option<T> {
|
||||
match self {
|
||||
Self::Ok(x) => Some(x),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<KeychainResult<T>> for Option<T> {
|
||||
fn from(result: KeychainResult<T>) -> Self {
|
||||
result.into_ok()
|
||||
}
|
||||
}
|
||||
|
||||
/// Manages keys with associated ids. Cloning will result in a copy pointing to the same underlying
|
||||
/// storage, which enables support of managing the keys across multiple threads.
|
||||
#[derive(Debug)]
|
||||
pub struct Keychain<T = ()> {
|
||||
map: Arc<RwLock<HashMap<String, (HeapSecretKey, T)>>>,
|
||||
}
|
||||
|
||||
impl<T> Clone for Keychain<T> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
map: Arc::clone(&self.map),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Keychain<T> {
|
||||
/// Creates a new keychain without any keys.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
map: Arc::new(RwLock::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Stores a new `key` and `data` by a given `id`, returning the old data associated with the
|
||||
/// id if there was one already registered.
|
||||
pub async fn insert(&self, id: impl Into<String>, key: HeapSecretKey, data: T) -> Option<T> {
|
||||
self.map
|
||||
.write()
|
||||
.await
|
||||
.insert(id.into(), (key, data))
|
||||
.map(|(_, data)| data)
|
||||
}
|
||||
|
||||
/// Checks if there is an `id` stored within the keychain.
|
||||
pub async fn has_id(&self, id: impl AsRef<str>) -> bool {
|
||||
self.map.read().await.contains_key(id.as_ref())
|
||||
}
|
||||
|
||||
/// Checks if there is a key with the given `id` that matches the provided `key`.
|
||||
pub async fn has_key(&self, id: impl AsRef<str>, key: impl PartialEq<HeapSecretKey>) -> bool {
|
||||
self.map
|
||||
.read()
|
||||
.await
|
||||
.get(id.as_ref())
|
||||
.map(|(k, _)| key.eq(k))
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Removes a key and its data by a given `id`, returning the data if the `id` exists.
|
||||
pub async fn remove(&self, id: impl AsRef<str>) -> Option<T> {
|
||||
self.map
|
||||
.write()
|
||||
.await
|
||||
.remove(id.as_ref())
|
||||
.map(|(_, data)| data)
|
||||
}
|
||||
|
||||
/// Checks if there is a key with the given `id` that matches the provided `key`, returning the
|
||||
/// data if the `id` exists and the `key` matches.
|
||||
pub async fn remove_if_has_key(
|
||||
&self,
|
||||
id: impl AsRef<str>,
|
||||
key: impl PartialEq<HeapSecretKey>,
|
||||
) -> KeychainResult<T> {
|
||||
let id = id.as_ref();
|
||||
let mut lock = self.map.write().await;
|
||||
|
||||
match lock.get(id) {
|
||||
Some((k, _)) if key.eq(k) => KeychainResult::Ok(lock.remove(id).unwrap().1),
|
||||
Some(_) => KeychainResult::InvalidPassword,
|
||||
None => KeychainResult::InvalidId,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Keychain<()> {
|
||||
/// Stores a new `key by a given `id`.
|
||||
pub async fn put(&self, id: impl Into<String>, key: HeapSecretKey) {
|
||||
self.insert(id, key, ()).await;
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Keychain {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<HashMap<String, (HeapSecretKey, T)>> for Keychain<T> {
|
||||
/// Creates a new keychain populated with the provided `map`.
|
||||
fn from(map: HashMap<String, (HeapSecretKey, T)>) -> Self {
|
||||
Self {
|
||||
map: Arc::new(RwLock::new(map)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<HashMap<String, HeapSecretKey>> for Keychain<()> {
|
||||
/// Creates a new keychain populated with the provided `map`.
|
||||
fn from(map: HashMap<String, HeapSecretKey>) -> Self {
|
||||
Self::from(
|
||||
map.into_iter()
|
||||
.map(|(id, key)| (id, (key, ())))
|
||||
.collect::<HashMap<String, (HeapSecretKey, ())>>(),
|
||||
)
|
||||
}
|
||||
}
|
@ -0,0 +1,376 @@
|
||||
use super::{super::HeapSecretKey, msg::*, Authenticator};
|
||||
use async_trait::async_trait;
|
||||
use log::*;
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
|
||||
mod none;
|
||||
mod static_key;
|
||||
|
||||
pub use none::*;
|
||||
pub use static_key::*;
|
||||
|
||||
/// Supports authenticating using a variety of methods
|
||||
pub struct Verifier {
|
||||
methods: HashMap<&'static str, Box<dyn AuthenticationMethod>>,
|
||||
}
|
||||
|
||||
impl Verifier {
|
||||
pub fn new<I>(methods: I) -> Self
|
||||
where
|
||||
I: IntoIterator<Item = Box<dyn AuthenticationMethod>>,
|
||||
{
|
||||
let mut m = HashMap::new();
|
||||
|
||||
for method in methods {
|
||||
m.insert(method.id(), method);
|
||||
}
|
||||
|
||||
Self { methods: m }
|
||||
}
|
||||
|
||||
/// Creates a verifier with no methods.
|
||||
pub fn empty() -> Self {
|
||||
Self {
|
||||
methods: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a verifier that uses the [`NoneAuthenticationMethod`] exclusively.
|
||||
pub fn none() -> Self {
|
||||
Self::new(vec![
|
||||
Box::new(NoneAuthenticationMethod::new()) as Box<dyn AuthenticationMethod>
|
||||
])
|
||||
}
|
||||
|
||||
/// Creates a verifier that uses the [`StaticKeyAuthenticationMethod`] exclusively.
|
||||
pub fn static_key(key: impl Into<HeapSecretKey>) -> Self {
|
||||
Self::new(vec![
|
||||
Box::new(StaticKeyAuthenticationMethod::new(key)) as Box<dyn AuthenticationMethod>
|
||||
])
|
||||
}
|
||||
|
||||
/// Returns an iterator over the ids of the methods supported by the verifier
|
||||
pub fn methods(&self) -> impl Iterator<Item = &'static str> + '_ {
|
||||
self.methods.keys().copied()
|
||||
}
|
||||
|
||||
/// Attempts to verify by submitting challenges using the `authenticator` provided. Returns the
|
||||
/// id of the authentication method that succeeded. Fails if no authentication method succeeds.
|
||||
pub async fn verify(&self, authenticator: &mut dyn Authenticator) -> io::Result<&'static str> {
|
||||
// Initiate the process to get methods to use
|
||||
let response = authenticator
|
||||
.initialize(Initialization {
|
||||
methods: self.methods.keys().map(ToString::to_string).collect(),
|
||||
})
|
||||
.await?;
|
||||
|
||||
for method in response.methods {
|
||||
match self.methods.get(method.as_str()) {
|
||||
Some(method) => {
|
||||
// Report the authentication method
|
||||
authenticator
|
||||
.start_method(StartMethod {
|
||||
method: method.id().to_string(),
|
||||
})
|
||||
.await?;
|
||||
|
||||
// Perform the actual authentication
|
||||
if method.authenticate(authenticator).await.is_ok() {
|
||||
authenticator.finished().await?;
|
||||
return Ok(method.id());
|
||||
}
|
||||
}
|
||||
None => {
|
||||
trace!("Skipping authentication {method} as it is not available or supported");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(io::Error::new(
|
||||
io::ErrorKind::PermissionDenied,
|
||||
"No authentication method succeeded",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<Box<dyn AuthenticationMethod>>> for Verifier {
|
||||
fn from(methods: Vec<Box<dyn AuthenticationMethod>>) -> Self {
|
||||
Self::new(methods)
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents an interface to authenticate using some method
|
||||
#[async_trait]
|
||||
pub trait AuthenticationMethod: Send + Sync {
|
||||
/// Returns a unique id to distinguish the method from other methods
|
||||
fn id(&self) -> &'static str;
|
||||
|
||||
/// Performs authentication using the `authenticator` to submit challenges and other
|
||||
/// information based on the authentication method
|
||||
async fn authenticate(&self, authenticator: &mut dyn Authenticator) -> io::Result<()>;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::common::FramedTransport;
|
||||
use test_log::test;
|
||||
|
||||
struct SuccessAuthenticationMethod;
|
||||
|
||||
#[async_trait]
|
||||
impl AuthenticationMethod for SuccessAuthenticationMethod {
|
||||
fn id(&self) -> &'static str {
|
||||
"success"
|
||||
}
|
||||
|
||||
async fn authenticate(&self, _: &mut dyn Authenticator) -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
struct FailAuthenticationMethod;
|
||||
|
||||
#[async_trait]
|
||||
impl AuthenticationMethod for FailAuthenticationMethod {
|
||||
fn id(&self) -> &'static str {
|
||||
"fail"
|
||||
}
|
||||
|
||||
async fn authenticate(&self, _: &mut dyn Authenticator) -> io::Result<()> {
|
||||
Err(io::Error::from(io::ErrorKind::Other))
|
||||
}
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn verifier_should_fail_to_verify_if_initialization_fails() {
|
||||
let (mut t1, mut t2) = FramedTransport::test_pair(100);
|
||||
|
||||
// Queue up a response to the initialization request
|
||||
t2.write_frame(b"invalid initialization response")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let methods: Vec<Box<dyn AuthenticationMethod>> =
|
||||
vec![Box::new(SuccessAuthenticationMethod)];
|
||||
let verifier = Verifier::from(methods);
|
||||
verifier.verify(&mut t1).await.unwrap_err();
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn verifier_should_fail_to_verify_if_fails_to_send_finished_indicator_after_success() {
|
||||
let (mut t1, mut t2) = FramedTransport::test_pair(100);
|
||||
|
||||
// Queue up a response to the initialization request
|
||||
t2.write_frame_for(&AuthenticationResponse::Initialization(
|
||||
InitializationResponse {
|
||||
methods: vec![SuccessAuthenticationMethod.id().to_string()]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Then drop the transport so it cannot receive anything else
|
||||
drop(t2);
|
||||
|
||||
let methods: Vec<Box<dyn AuthenticationMethod>> =
|
||||
vec![Box::new(SuccessAuthenticationMethod)];
|
||||
let verifier = Verifier::from(methods);
|
||||
assert_eq!(
|
||||
verifier.verify(&mut t1).await.unwrap_err().kind(),
|
||||
io::ErrorKind::WriteZero
|
||||
);
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn verifier_should_fail_to_verify_if_has_no_authentication_methods() {
|
||||
let (mut t1, mut t2) = FramedTransport::test_pair(100);
|
||||
|
||||
// Queue up a response to the initialization request
|
||||
t2.write_frame_for(&AuthenticationResponse::Initialization(
|
||||
InitializationResponse {
|
||||
methods: vec![SuccessAuthenticationMethod.id().to_string()]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let methods: Vec<Box<dyn AuthenticationMethod>> = vec![];
|
||||
let verifier = Verifier::from(methods);
|
||||
verifier.verify(&mut t1).await.unwrap_err();
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn verifier_should_fail_to_verify_if_initialization_yields_no_valid_authentication_methods(
|
||||
) {
|
||||
let (mut t1, mut t2) = FramedTransport::test_pair(100);
|
||||
|
||||
// Queue up a response to the initialization request
|
||||
t2.write_frame_for(&AuthenticationResponse::Initialization(
|
||||
InitializationResponse {
|
||||
methods: vec!["other".to_string()].into_iter().collect(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let methods: Vec<Box<dyn AuthenticationMethod>> =
|
||||
vec![Box::new(SuccessAuthenticationMethod)];
|
||||
let verifier = Verifier::from(methods);
|
||||
verifier.verify(&mut t1).await.unwrap_err();
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn verifier_should_fail_to_verify_if_no_authentication_method_succeeds() {
|
||||
let (mut t1, mut t2) = FramedTransport::test_pair(100);
|
||||
|
||||
// Queue up a response to the initialization request
|
||||
t2.write_frame_for(&AuthenticationResponse::Initialization(
|
||||
InitializationResponse {
|
||||
methods: vec![FailAuthenticationMethod.id().to_string()]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let methods: Vec<Box<dyn AuthenticationMethod>> = vec![Box::new(FailAuthenticationMethod)];
|
||||
let verifier = Verifier::from(methods);
|
||||
verifier.verify(&mut t1).await.unwrap_err();
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn verifier_should_return_id_of_authentication_method_upon_success() {
|
||||
let (mut t1, mut t2) = FramedTransport::test_pair(100);
|
||||
|
||||
// Queue up a response to the initialization request
|
||||
t2.write_frame_for(&AuthenticationResponse::Initialization(
|
||||
InitializationResponse {
|
||||
methods: vec![SuccessAuthenticationMethod.id().to_string()]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let methods: Vec<Box<dyn AuthenticationMethod>> =
|
||||
vec![Box::new(SuccessAuthenticationMethod)];
|
||||
let verifier = Verifier::from(methods);
|
||||
assert_eq!(
|
||||
verifier.verify(&mut t1).await.unwrap(),
|
||||
SuccessAuthenticationMethod.id()
|
||||
);
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn verifier_should_try_authentication_methods_in_order_until_one_succeeds() {
|
||||
let (mut t1, mut t2) = FramedTransport::test_pair(100);
|
||||
|
||||
// Queue up a response to the initialization request
|
||||
t2.write_frame_for(&AuthenticationResponse::Initialization(
|
||||
InitializationResponse {
|
||||
methods: vec![
|
||||
FailAuthenticationMethod.id().to_string(),
|
||||
SuccessAuthenticationMethod.id().to_string(),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let methods: Vec<Box<dyn AuthenticationMethod>> = vec![
|
||||
Box::new(FailAuthenticationMethod),
|
||||
Box::new(SuccessAuthenticationMethod),
|
||||
];
|
||||
let verifier = Verifier::from(methods);
|
||||
assert_eq!(
|
||||
verifier.verify(&mut t1).await.unwrap(),
|
||||
SuccessAuthenticationMethod.id()
|
||||
);
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn verifier_should_send_start_method_before_attempting_each_method() {
|
||||
let (mut t1, mut t2) = FramedTransport::test_pair(100);
|
||||
|
||||
// Queue up a response to the initialization request
|
||||
t2.write_frame_for(&AuthenticationResponse::Initialization(
|
||||
InitializationResponse {
|
||||
methods: vec![
|
||||
FailAuthenticationMethod.id().to_string(),
|
||||
SuccessAuthenticationMethod.id().to_string(),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let methods: Vec<Box<dyn AuthenticationMethod>> = vec![
|
||||
Box::new(FailAuthenticationMethod),
|
||||
Box::new(SuccessAuthenticationMethod),
|
||||
];
|
||||
Verifier::from(methods).verify(&mut t1).await.unwrap();
|
||||
|
||||
// Check that we get a start method for each of the attempted methods
|
||||
match t2.read_frame_as::<Authentication>().await.unwrap().unwrap() {
|
||||
Authentication::Initialization(_) => (),
|
||||
x => panic!("Unexpected response: {x:?}"),
|
||||
}
|
||||
match t2.read_frame_as::<Authentication>().await.unwrap().unwrap() {
|
||||
Authentication::StartMethod(x) => assert_eq!(x.method, FailAuthenticationMethod.id()),
|
||||
x => panic!("Unexpected response: {x:?}"),
|
||||
}
|
||||
match t2.read_frame_as::<Authentication>().await.unwrap().unwrap() {
|
||||
Authentication::StartMethod(x) => {
|
||||
assert_eq!(x.method, SuccessAuthenticationMethod.id())
|
||||
}
|
||||
x => panic!("Unexpected response: {x:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn verifier_should_send_finished_when_a_method_succeeds() {
|
||||
let (mut t1, mut t2) = FramedTransport::test_pair(100);
|
||||
|
||||
// Queue up a response to the initialization request
|
||||
t2.write_frame_for(&AuthenticationResponse::Initialization(
|
||||
InitializationResponse {
|
||||
methods: vec![
|
||||
FailAuthenticationMethod.id().to_string(),
|
||||
SuccessAuthenticationMethod.id().to_string(),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let methods: Vec<Box<dyn AuthenticationMethod>> = vec![
|
||||
Box::new(FailAuthenticationMethod),
|
||||
Box::new(SuccessAuthenticationMethod),
|
||||
];
|
||||
Verifier::from(methods).verify(&mut t1).await.unwrap();
|
||||
|
||||
// Clear out the initialization and start methods
|
||||
t2.read_frame_as::<Authentication>().await.unwrap().unwrap();
|
||||
t2.read_frame_as::<Authentication>().await.unwrap().unwrap();
|
||||
t2.read_frame_as::<Authentication>().await.unwrap().unwrap();
|
||||
|
||||
match t2.read_frame_as::<Authentication>().await.unwrap().unwrap() {
|
||||
Authentication::Finished => (),
|
||||
x => panic!("Unexpected response: {x:?}"),
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,32 @@
|
||||
use super::{AuthenticationMethod, Authenticator};
|
||||
use async_trait::async_trait;
|
||||
use std::io;
|
||||
|
||||
/// Authenticaton method for a static secret key
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct NoneAuthenticationMethod;
|
||||
|
||||
impl NoneAuthenticationMethod {
|
||||
#[inline]
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for NoneAuthenticationMethod {
|
||||
#[inline]
|
||||
fn default() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AuthenticationMethod for NoneAuthenticationMethod {
|
||||
fn id(&self) -> &'static str {
|
||||
"none"
|
||||
}
|
||||
|
||||
async fn authenticate(&self, _: &mut dyn Authenticator) -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
@ -0,0 +1,129 @@
|
||||
use super::{AuthenticationMethod, Authenticator, Challenge, Error, Question};
|
||||
use crate::common::HeapSecretKey;
|
||||
use async_trait::async_trait;
|
||||
use std::io;
|
||||
|
||||
/// Authenticaton method for a static secret key
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct StaticKeyAuthenticationMethod {
|
||||
key: HeapSecretKey,
|
||||
}
|
||||
|
||||
impl StaticKeyAuthenticationMethod {
|
||||
#[inline]
|
||||
pub fn new(key: impl Into<HeapSecretKey>) -> Self {
|
||||
Self { key: key.into() }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AuthenticationMethod for StaticKeyAuthenticationMethod {
|
||||
fn id(&self) -> &'static str {
|
||||
"static_key"
|
||||
}
|
||||
|
||||
async fn authenticate(&self, authenticator: &mut dyn Authenticator) -> io::Result<()> {
|
||||
let response = authenticator
|
||||
.challenge(Challenge {
|
||||
questions: vec![Question {
|
||||
label: "key".to_string(),
|
||||
text: "Provide a key: ".to_string(),
|
||||
options: Default::default(),
|
||||
}],
|
||||
options: Default::default(),
|
||||
})
|
||||
.await?;
|
||||
|
||||
if response.answers.is_empty() {
|
||||
return Err(Error::non_fatal("missing answer").into_io_permission_denied());
|
||||
}
|
||||
|
||||
match response
|
||||
.answers
|
||||
.into_iter()
|
||||
.next()
|
||||
.unwrap()
|
||||
.parse::<HeapSecretKey>()
|
||||
{
|
||||
Ok(key) if key == self.key => Ok(()),
|
||||
_ => Err(Error::non_fatal("answer does not match key").into_io_permission_denied()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::common::{
|
||||
authentication::msg::{AuthenticationResponse, ChallengeResponse},
|
||||
FramedTransport,
|
||||
};
|
||||
use test_log::test;
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn authenticate_should_fail_if_key_challenge_fails() {
|
||||
let method = StaticKeyAuthenticationMethod::new(b"".to_vec());
|
||||
let (mut t1, mut t2) = FramedTransport::test_pair(100);
|
||||
|
||||
// Queue up an invalid frame for our challenge to ensure it fails
|
||||
t2.write_frame(b"invalid initialization response")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
method.authenticate(&mut t1).await.unwrap_err().kind(),
|
||||
io::ErrorKind::InvalidData
|
||||
);
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn authenticate_should_fail_if_no_answer_included_in_challenge_response() {
|
||||
let method = StaticKeyAuthenticationMethod::new(b"".to_vec());
|
||||
let (mut t1, mut t2) = FramedTransport::test_pair(100);
|
||||
|
||||
// Queue up a response to the initialization request
|
||||
t2.write_frame_for(&AuthenticationResponse::Challenge(ChallengeResponse {
|
||||
answers: Vec::new(),
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
method.authenticate(&mut t1).await.unwrap_err().kind(),
|
||||
io::ErrorKind::PermissionDenied
|
||||
);
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn authenticate_should_fail_if_answer_does_not_match_key() {
|
||||
let method = StaticKeyAuthenticationMethod::new(b"answer".to_vec());
|
||||
let (mut t1, mut t2) = FramedTransport::test_pair(100);
|
||||
|
||||
// Queue up a response to the initialization request
|
||||
t2.write_frame_for(&AuthenticationResponse::Challenge(ChallengeResponse {
|
||||
answers: vec![HeapSecretKey::from(b"some key".to_vec()).to_string()],
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
method.authenticate(&mut t1).await.unwrap_err().kind(),
|
||||
io::ErrorKind::PermissionDenied
|
||||
);
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn authenticate_should_succeed_if_answer_matches_key() {
|
||||
let method = StaticKeyAuthenticationMethod::new(b"answer".to_vec());
|
||||
let (mut t1, mut t2) = FramedTransport::test_pair(100);
|
||||
|
||||
// Queue up a response to the initialization request
|
||||
t2.write_frame_for(&AuthenticationResponse::Challenge(ChallengeResponse {
|
||||
answers: vec![HeapSecretKey::from(b"answer".to_vec()).to_string()],
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
method.authenticate(&mut t1).await.unwrap();
|
||||
}
|
||||
}
|
@ -0,0 +1,216 @@
|
||||
use derive_more::{Display, Error, From};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Represents messages from an authenticator that act as initiators such as providing
|
||||
/// a challenge, verifying information, presenting information, or highlighting an error
|
||||
#[derive(Clone, Debug, From, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case", tag = "type")]
|
||||
pub enum Authentication {
|
||||
/// Indicates the beginning of authentication, providing available methods
|
||||
#[serde(rename = "auth_initialization")]
|
||||
Initialization(Initialization),
|
||||
|
||||
/// Indicates that authentication is starting for the specific `method`
|
||||
#[serde(rename = "auth_start_method")]
|
||||
StartMethod(StartMethod),
|
||||
|
||||
/// Issues a challenge to be answered
|
||||
#[serde(rename = "auth_challenge")]
|
||||
Challenge(Challenge),
|
||||
|
||||
/// Requests verification of some text
|
||||
#[serde(rename = "auth_verification")]
|
||||
Verification(Verification),
|
||||
|
||||
/// Reports some information associated with authentication
|
||||
#[serde(rename = "auth_info")]
|
||||
Info(Info),
|
||||
|
||||
/// Reports an error occurrred during authentication
|
||||
#[serde(rename = "auth_error")]
|
||||
Error(Error),
|
||||
|
||||
/// Indicates that the authentication of all methods is finished
|
||||
#[serde(rename = "auth_finished")]
|
||||
Finished,
|
||||
}
|
||||
|
||||
/// Represents the beginning of the authentication procedure
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct Initialization {
|
||||
/// Available methods to use for authentication
|
||||
pub methods: Vec<String>,
|
||||
}
|
||||
|
||||
/// Represents the start of authentication for some method
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub struct StartMethod {
|
||||
pub method: String,
|
||||
}
|
||||
|
||||
/// Represents a challenge comprising a series of questions to be presented
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct Challenge {
|
||||
pub questions: Vec<Question>,
|
||||
pub options: HashMap<String, String>,
|
||||
}
|
||||
|
||||
/// Represents an ask to verify some information
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct Verification {
|
||||
pub kind: VerificationKind,
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
/// Represents some information to be presented related to authentication
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct Info {
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
/// Represents authentication messages that are responses to authenticator requests such
|
||||
/// as answers to challenges or verifying information
|
||||
#[derive(Clone, Debug, From, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case", tag = "type")]
|
||||
pub enum AuthenticationResponse {
|
||||
/// Contains response to initialization, providing details about which methods to use
|
||||
#[serde(rename = "auth_initialization_response")]
|
||||
Initialization(InitializationResponse),
|
||||
|
||||
/// Contains answers to challenge request
|
||||
#[serde(rename = "auth_challenge_response")]
|
||||
Challenge(ChallengeResponse),
|
||||
|
||||
/// Contains response to a verification request
|
||||
#[serde(rename = "auth_verification_response")]
|
||||
Verification(VerificationResponse),
|
||||
}
|
||||
|
||||
/// Represents a response to initialization to specify which authentication methods to pursue
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct InitializationResponse {
|
||||
/// Methods to use (in order as provided)
|
||||
pub methods: Vec<String>,
|
||||
}
|
||||
|
||||
/// Represents the answers to a previously-asked challenge associated with authentication
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct ChallengeResponse {
|
||||
/// Answers to challenge questions (in order relative to questions)
|
||||
pub answers: Vec<String>,
|
||||
}
|
||||
|
||||
/// Represents the answer to a previously-asked verification associated with authentication
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct VerificationResponse {
|
||||
/// Whether or not the verification was deemed valid
|
||||
pub valid: bool,
|
||||
}
|
||||
|
||||
/// Represents the type of verification being requested
|
||||
#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum VerificationKind {
|
||||
/// An ask to verify the host such as with SSH
|
||||
#[display(fmt = "host")]
|
||||
Host,
|
||||
|
||||
/// When the verification is unknown (happens when other side is unaware of the kind)
|
||||
#[display(fmt = "unknown")]
|
||||
#[serde(other)]
|
||||
Unknown,
|
||||
}
|
||||
|
||||
impl VerificationKind {
|
||||
/// Returns all variants except "unknown"
|
||||
pub const fn known_variants() -> &'static [Self] {
|
||||
&[Self::Host]
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a single question in a challenge associated with authentication
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct Question {
|
||||
/// Label associated with the question for more programmatic usage
|
||||
pub label: String,
|
||||
|
||||
/// The text of the question (used for display purposes)
|
||||
pub text: String,
|
||||
|
||||
/// Any options information specific to a particular auth domain
|
||||
/// such as including a username and instructions for SSH authentication
|
||||
pub options: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl Question {
|
||||
/// Creates a new question without any options data using `text` for both label and text
|
||||
pub fn new(text: impl Into<String>) -> Self {
|
||||
let text = text.into();
|
||||
|
||||
Self {
|
||||
label: text.clone(),
|
||||
text,
|
||||
options: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents some error that occurred during authentication
|
||||
#[derive(Clone, Debug, Display, Error, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[display(fmt = "{}: {}", kind, text)]
|
||||
pub struct Error {
|
||||
/// Represents the kind of error
|
||||
pub kind: ErrorKind,
|
||||
|
||||
/// Description of the error
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
impl Error {
|
||||
/// Creates a fatal error
|
||||
pub fn fatal(text: impl Into<String>) -> Self {
|
||||
Self {
|
||||
kind: ErrorKind::Fatal,
|
||||
text: text.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a non-fatal error
|
||||
pub fn non_fatal(text: impl Into<String>) -> Self {
|
||||
Self {
|
||||
kind: ErrorKind::Error,
|
||||
text: text.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if error represents a fatal error, meaning that there is no recovery possible
|
||||
/// from this error
|
||||
pub fn is_fatal(&self) -> bool {
|
||||
self.kind.is_fatal()
|
||||
}
|
||||
|
||||
/// Converts the error into a [`std::io::Error`] representing permission denied
|
||||
pub fn into_io_permission_denied(self) -> std::io::Error {
|
||||
std::io::Error::new(std::io::ErrorKind::PermissionDenied, self)
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents the type of error encountered during authentication
|
||||
#[derive(Copy, Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ErrorKind {
|
||||
/// Error is unrecoverable
|
||||
Fatal,
|
||||
|
||||
/// Error is recoverable
|
||||
Error,
|
||||
}
|
||||
|
||||
impl ErrorKind {
|
||||
/// Returns true if error kind represents a fatal error, meaning that there is no recovery
|
||||
/// possible from this error
|
||||
pub fn is_fatal(self) -> bool {
|
||||
matches!(self, Self::Fatal)
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -1,4 +1,4 @@
|
||||
use crate::Listener;
|
||||
use super::Listener;
|
||||
use async_trait::async_trait;
|
||||
use std::io;
|
||||
|
@ -1,4 +1,4 @@
|
||||
use crate::Listener;
|
||||
use super::Listener;
|
||||
use async_trait::async_trait;
|
||||
use derive_more::From;
|
||||
use std::io;
|
@ -0,0 +1,628 @@
|
||||
/// Represents a generic id type
|
||||
pub type Id = String;
|
||||
|
||||
mod request;
|
||||
mod response;
|
||||
|
||||
pub use request::*;
|
||||
pub use response::*;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
enum MsgPackStrParseError {
|
||||
InvalidFormat,
|
||||
Utf8Error(std::str::Utf8Error),
|
||||
}
|
||||
|
||||
/// Writes the given str to the end of `buf` as the str's msgpack representation.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `s.len() >= 2 ^ 32` as the maximum str length for a msgpack str is `(2 ^ 32) - 1`.
|
||||
fn write_str_msg_pack(s: &str, buf: &mut Vec<u8>) {
|
||||
assert!(
|
||||
s.len() < 2usize.pow(32),
|
||||
"str cannot be longer than (2^32)-1 bytes"
|
||||
);
|
||||
|
||||
if s.len() < 32 {
|
||||
buf.push(s.len() as u8 | 0b10100000);
|
||||
} else if s.len() < 2usize.pow(8) {
|
||||
buf.push(0xd9);
|
||||
buf.push(s.len() as u8);
|
||||
} else if s.len() < 2usize.pow(16) {
|
||||
buf.push(0xda);
|
||||
for b in (s.len() as u16).to_be_bytes() {
|
||||
buf.push(b);
|
||||
}
|
||||
} else {
|
||||
buf.push(0xdb);
|
||||
for b in (s.len() as u32).to_be_bytes() {
|
||||
buf.push(b);
|
||||
}
|
||||
}
|
||||
|
||||
buf.extend_from_slice(s.as_bytes());
|
||||
}
|
||||
|
||||
/// Parse msgpack str, returning remaining bytes and str on success, or error on failure.
|
||||
fn parse_msg_pack_str(input: &[u8]) -> Result<(&[u8], &str), MsgPackStrParseError> {
|
||||
let ilen = input.len();
|
||||
if ilen == 0 {
|
||||
return Err(MsgPackStrParseError::InvalidFormat);
|
||||
}
|
||||
|
||||
// * fixstr using 0xa0 - 0xbf to mark the start of the str where < 32 bytes
|
||||
// * str 8 (0xd9) if up to (2^8)-1 bytes, using next byte for len
|
||||
// * str 16 (0xda) if up to (2^16)-1 bytes, using next two bytes for len
|
||||
// * str 32 (0xdb) if up to (2^32)-1 bytes, using next four bytes for len
|
||||
let (input, len): (&[u8], usize) = if input[0] >= 0xa0 && input[0] <= 0xbf {
|
||||
(&input[1..], (input[0] & 0b00011111).into())
|
||||
} else if input[0] == 0xd9 && ilen > 2 {
|
||||
(&input[2..], input[1].into())
|
||||
} else if input[0] == 0xda && ilen > 3 {
|
||||
(&input[3..], u16::from_be_bytes([input[1], input[2]]).into())
|
||||
} else if input[0] == 0xdb && ilen > 5 {
|
||||
(
|
||||
&input[5..],
|
||||
u32::from_be_bytes([input[1], input[2], input[3], input[4]])
|
||||
.try_into()
|
||||
.unwrap(),
|
||||
)
|
||||
} else {
|
||||
return Err(MsgPackStrParseError::InvalidFormat);
|
||||
};
|
||||
|
||||
let s = match std::str::from_utf8(&input[..len]) {
|
||||
Ok(s) => s,
|
||||
Err(x) => return Err(MsgPackStrParseError::Utf8Error(x)),
|
||||
};
|
||||
|
||||
Ok((&input[len..], s))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
mod write_str_msg_pack {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn should_support_fixstr() {
|
||||
// 0-byte str
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("", &mut buf);
|
||||
assert_eq!(buf, &[0xa0]);
|
||||
|
||||
// 1-byte str
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("a", &mut buf);
|
||||
assert_eq!(buf, &[0xa1, b'a']);
|
||||
|
||||
// 2-byte str
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("ab", &mut buf);
|
||||
assert_eq!(buf, &[0xa2, b'a', b'b']);
|
||||
|
||||
// 3-byte str
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("abc", &mut buf);
|
||||
assert_eq!(buf, &[0xa3, b'a', b'b', b'c']);
|
||||
|
||||
// 4-byte str
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("abcd", &mut buf);
|
||||
assert_eq!(buf, &[0xa4, b'a', b'b', b'c', b'd']);
|
||||
|
||||
// 5-byte str
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("abcde", &mut buf);
|
||||
assert_eq!(buf, &[0xa5, b'a', b'b', b'c', b'd', b'e']);
|
||||
|
||||
// 6-byte str
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("abcdef", &mut buf);
|
||||
assert_eq!(buf, &[0xa6, b'a', b'b', b'c', b'd', b'e', b'f']);
|
||||
|
||||
// 7-byte str
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("abcdefg", &mut buf);
|
||||
assert_eq!(buf, &[0xa7, b'a', b'b', b'c', b'd', b'e', b'f', b'g']);
|
||||
|
||||
// 8-byte str
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("abcdefgh", &mut buf);
|
||||
assert_eq!(buf, &[0xa8, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h']);
|
||||
|
||||
// 9-byte str
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("abcdefghi", &mut buf);
|
||||
assert_eq!(
|
||||
buf,
|
||||
&[0xa9, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i']
|
||||
);
|
||||
|
||||
// 10-byte str
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("abcdefghij", &mut buf);
|
||||
assert_eq!(
|
||||
buf,
|
||||
&[0xaa, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j']
|
||||
);
|
||||
|
||||
// 11-byte str
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("abcdefghijk", &mut buf);
|
||||
assert_eq!(
|
||||
buf,
|
||||
&[0xab, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k']
|
||||
);
|
||||
|
||||
// 12-byte str
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("abcdefghijkl", &mut buf);
|
||||
assert_eq!(
|
||||
buf,
|
||||
&[0xac, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l']
|
||||
);
|
||||
|
||||
// 13-byte str
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("abcdefghijklm", &mut buf);
|
||||
assert_eq!(
|
||||
buf,
|
||||
&[
|
||||
0xad, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
|
||||
b'm'
|
||||
]
|
||||
);
|
||||
|
||||
// 14-byte str
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("abcdefghijklmn", &mut buf);
|
||||
assert_eq!(
|
||||
buf,
|
||||
&[
|
||||
0xae, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
|
||||
b'm', b'n'
|
||||
]
|
||||
);
|
||||
|
||||
// 15-byte str
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("abcdefghijklmno", &mut buf);
|
||||
assert_eq!(
|
||||
buf,
|
||||
&[
|
||||
0xaf, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
|
||||
b'm', b'n', b'o'
|
||||
]
|
||||
);
|
||||
|
||||
// 16-byte str
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("abcdefghijklmnop", &mut buf);
|
||||
assert_eq!(
|
||||
buf,
|
||||
&[
|
||||
0xb0, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
|
||||
b'm', b'n', b'o', b'p'
|
||||
]
|
||||
);
|
||||
|
||||
// 17-byte str
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("abcdefghijklmnopq", &mut buf);
|
||||
assert_eq!(
|
||||
buf,
|
||||
&[
|
||||
0xb1, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
|
||||
b'm', b'n', b'o', b'p', b'q'
|
||||
]
|
||||
);
|
||||
|
||||
// 18-byte str
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("abcdefghijklmnopqr", &mut buf);
|
||||
assert_eq!(
|
||||
buf,
|
||||
&[
|
||||
0xb2, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
|
||||
b'm', b'n', b'o', b'p', b'q', b'r'
|
||||
]
|
||||
);
|
||||
|
||||
// 19-byte str
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("abcdefghijklmnopqrs", &mut buf);
|
||||
assert_eq!(
|
||||
buf,
|
||||
&[
|
||||
0xb3, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
|
||||
b'm', b'n', b'o', b'p', b'q', b'r', b's'
|
||||
]
|
||||
);
|
||||
|
||||
// 20-byte str
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("abcdefghijklmnopqrst", &mut buf);
|
||||
assert_eq!(
|
||||
buf,
|
||||
&[
|
||||
0xb4, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
|
||||
b'm', b'n', b'o', b'p', b'q', b'r', b's', b't'
|
||||
]
|
||||
);
|
||||
|
||||
// 21-byte str
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("abcdefghijklmnopqrstu", &mut buf);
|
||||
assert_eq!(
|
||||
buf,
|
||||
&[
|
||||
0xb5, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
|
||||
b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u'
|
||||
]
|
||||
);
|
||||
|
||||
// 22-byte str
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("abcdefghijklmnopqrstuv", &mut buf);
|
||||
assert_eq!(
|
||||
buf,
|
||||
&[
|
||||
0xb6, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
|
||||
b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v'
|
||||
]
|
||||
);
|
||||
|
||||
// 23-byte str
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("abcdefghijklmnopqrstuvw", &mut buf);
|
||||
assert_eq!(
|
||||
buf,
|
||||
&[
|
||||
0xb7, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
|
||||
b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w'
|
||||
]
|
||||
);
|
||||
|
||||
// 24-byte str
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("abcdefghijklmnopqrstuvwx", &mut buf);
|
||||
assert_eq!(
|
||||
buf,
|
||||
&[
|
||||
0xb8, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
|
||||
b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x'
|
||||
]
|
||||
);
|
||||
|
||||
// 25-byte str
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("abcdefghijklmnopqrstuvwxy", &mut buf);
|
||||
assert_eq!(
|
||||
buf,
|
||||
&[
|
||||
0xb9, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
|
||||
b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y'
|
||||
]
|
||||
);
|
||||
|
||||
// 26-byte str
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("abcdefghijklmnopqrstuvwxyz", &mut buf);
|
||||
assert_eq!(
|
||||
buf,
|
||||
&[
|
||||
0xba, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
|
||||
b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y',
|
||||
b'z'
|
||||
]
|
||||
);
|
||||
|
||||
// 27-byte str
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("abcdefghijklmnopqrstuvwxyz0", &mut buf);
|
||||
assert_eq!(
|
||||
buf,
|
||||
&[
|
||||
0xbb, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
|
||||
b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y',
|
||||
b'z', b'0'
|
||||
]
|
||||
);
|
||||
|
||||
// 28-byte str
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("abcdefghijklmnopqrstuvwxyz01", &mut buf);
|
||||
assert_eq!(
|
||||
buf,
|
||||
&[
|
||||
0xbc, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
|
||||
b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y',
|
||||
b'z', b'0', b'1'
|
||||
]
|
||||
);
|
||||
|
||||
// 29-byte str
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("abcdefghijklmnopqrstuvwxyz012", &mut buf);
|
||||
assert_eq!(
|
||||
buf,
|
||||
&[
|
||||
0xbd, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
|
||||
b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y',
|
||||
b'z', b'0', b'1', b'2'
|
||||
]
|
||||
);
|
||||
|
||||
// 30-byte str
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("abcdefghijklmnopqrstuvwxyz0123", &mut buf);
|
||||
assert_eq!(
|
||||
buf,
|
||||
&[
|
||||
0xbe, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
|
||||
b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y',
|
||||
b'z', b'0', b'1', b'2', b'3'
|
||||
]
|
||||
);
|
||||
|
||||
// 31-byte str is maximum len of fixstr
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack("abcdefghijklmnopqrstuvwxyz01234", &mut buf);
|
||||
assert_eq!(
|
||||
buf,
|
||||
&[
|
||||
0xbf, b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
|
||||
b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y',
|
||||
b'z', b'0', b'1', b'2', b'3', b'4'
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_str_8() {
|
||||
let input = "a".repeat(32);
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack(&input, &mut buf);
|
||||
assert_eq!(buf[0], 0xd9);
|
||||
assert_eq!(buf[1], input.len() as u8);
|
||||
assert_eq!(&buf[2..], input.as_bytes());
|
||||
|
||||
let input = "a".repeat(2usize.pow(8) - 1);
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack(&input, &mut buf);
|
||||
assert_eq!(buf[0], 0xd9);
|
||||
assert_eq!(buf[1], input.len() as u8);
|
||||
assert_eq!(&buf[2..], input.as_bytes());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_str_16() {
|
||||
let input = "a".repeat(2usize.pow(8));
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack(&input, &mut buf);
|
||||
assert_eq!(buf[0], 0xda);
|
||||
assert_eq!(&buf[1..3], &(input.len() as u16).to_be_bytes());
|
||||
assert_eq!(&buf[3..], input.as_bytes());
|
||||
|
||||
let input = "a".repeat(2usize.pow(16) - 1);
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack(&input, &mut buf);
|
||||
assert_eq!(buf[0], 0xda);
|
||||
assert_eq!(&buf[1..3], &(input.len() as u16).to_be_bytes());
|
||||
assert_eq!(&buf[3..], input.as_bytes());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_str_32() {
|
||||
let input = "a".repeat(2usize.pow(16));
|
||||
let mut buf = Vec::new();
|
||||
write_str_msg_pack(&input, &mut buf);
|
||||
assert_eq!(buf[0], 0xdb);
|
||||
assert_eq!(&buf[1..5], &(input.len() as u32).to_be_bytes());
|
||||
assert_eq!(&buf[5..], input.as_bytes());
|
||||
}
|
||||
}
|
||||
|
||||
mod parse_msg_pack_str {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn should_be_able_to_parse_fixstr() {
|
||||
// Empty str
|
||||
let (input, s) = parse_msg_pack_str(&[0xa0]).unwrap();
|
||||
assert!(input.is_empty());
|
||||
assert_eq!(s, "");
|
||||
|
||||
// Single character
|
||||
let (input, s) = parse_msg_pack_str(&[0xa1, b'a']).unwrap();
|
||||
assert!(input.is_empty());
|
||||
assert_eq!(s, "a");
|
||||
|
||||
// 31 byte str
|
||||
let (input, s) = parse_msg_pack_str(&[
|
||||
0xbf, b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a',
|
||||
b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a',
|
||||
b'a', b'a', b'a', b'a',
|
||||
])
|
||||
.unwrap();
|
||||
assert!(input.is_empty());
|
||||
assert_eq!(s, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa");
|
||||
|
||||
// Verify that we only consume up to fixstr length
|
||||
assert_eq!(parse_msg_pack_str(&[0xa0, b'a']).unwrap().0, b"a");
|
||||
assert_eq!(
|
||||
parse_msg_pack_str(&[
|
||||
0xbf, b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a',
|
||||
b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a',
|
||||
b'a', b'a', b'a', b'a', b'a', b'a', b'b'
|
||||
])
|
||||
.unwrap()
|
||||
.0,
|
||||
b"b"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_be_able_to_parse_str_8() {
|
||||
// 32 byte str
|
||||
let (input, s) = parse_msg_pack_str(&[
|
||||
0xd9, 32, b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a',
|
||||
b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a', b'a',
|
||||
b'a', b'a', b'a', b'a', b'a', b'a',
|
||||
])
|
||||
.unwrap();
|
||||
assert!(input.is_empty());
|
||||
assert_eq!(s, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa");
|
||||
|
||||
// 2^8 - 1 (255) byte str
|
||||
let test_str = "a".repeat(2usize.pow(8) - 1);
|
||||
let mut input = vec![0xd9, 255];
|
||||
input.extend_from_slice(test_str.as_bytes());
|
||||
let (input, s) = parse_msg_pack_str(&input).unwrap();
|
||||
assert!(input.is_empty());
|
||||
assert_eq!(s, test_str);
|
||||
|
||||
// Verify that we only consume up to 2^8 - 1 length
|
||||
let mut input = vec![0xd9, 255];
|
||||
input.extend_from_slice(test_str.as_bytes());
|
||||
input.extend_from_slice(b"hello");
|
||||
let (input, s) = parse_msg_pack_str(&input).unwrap();
|
||||
assert_eq!(input, b"hello");
|
||||
assert_eq!(s, test_str);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_be_able_to_parse_str_16() {
|
||||
// 2^8 byte str (256)
|
||||
let test_str = "a".repeat(2usize.pow(8));
|
||||
let mut input = vec![0xda, 1, 0];
|
||||
input.extend_from_slice(test_str.as_bytes());
|
||||
let (input, s) = parse_msg_pack_str(&input).unwrap();
|
||||
assert!(input.is_empty());
|
||||
assert_eq!(s, test_str);
|
||||
|
||||
// 2^16 - 1 (65535) byte str
|
||||
let test_str = "a".repeat(2usize.pow(16) - 1);
|
||||
let mut input = vec![0xda, 255, 255];
|
||||
input.extend_from_slice(test_str.as_bytes());
|
||||
let (input, s) = parse_msg_pack_str(&input).unwrap();
|
||||
assert!(input.is_empty());
|
||||
assert_eq!(s, test_str);
|
||||
|
||||
// Verify that we only consume up to 2^16 - 1 length
|
||||
let mut input = vec![0xda, 255, 255];
|
||||
input.extend_from_slice(test_str.as_bytes());
|
||||
input.extend_from_slice(b"hello");
|
||||
let (input, s) = parse_msg_pack_str(&input).unwrap();
|
||||
assert_eq!(input, b"hello");
|
||||
assert_eq!(s, test_str);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_be_able_to_parse_str_32() {
|
||||
// 2^16 byte str
|
||||
let test_str = "a".repeat(2usize.pow(16));
|
||||
let mut input = vec![0xdb, 0, 1, 0, 0];
|
||||
input.extend_from_slice(test_str.as_bytes());
|
||||
let (input, s) = parse_msg_pack_str(&input).unwrap();
|
||||
assert!(input.is_empty());
|
||||
assert_eq!(s, test_str);
|
||||
|
||||
// NOTE: We are not going to run the below tests, not because they aren't valid but
|
||||
// because this generates a 4GB str which takes 20+ seconds to run
|
||||
|
||||
// 2^32 - 1 byte str (4294967295 bytes)
|
||||
/* let test_str = "a".repeat(2usize.pow(32) - 1);
|
||||
let mut input = vec![0xdb, 255, 255, 255, 255];
|
||||
input.extend_from_slice(test_str.as_bytes());
|
||||
let (input, s) = parse_msg_pack_str(&input).unwrap();
|
||||
assert!(input.is_empty());
|
||||
assert_eq!(s, test_str); */
|
||||
|
||||
// Verify that we only consume up to 2^32 - 1 length
|
||||
/* let mut input = vec![0xdb, 255, 255, 255, 255];
|
||||
input.extend_from_slice(test_str.as_bytes());
|
||||
input.extend_from_slice(b"hello");
|
||||
let (input, s) = parse_msg_pack_str(&input).unwrap();
|
||||
assert_eq!(input, b"hello");
|
||||
assert_eq!(s, test_str); */
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_fail_parsing_str_with_invalid_length() {
|
||||
// Make sure that parse doesn't fail looking for bytes after str 8 len
|
||||
assert_eq!(
|
||||
parse_msg_pack_str(&[0xd9]),
|
||||
Err(MsgPackStrParseError::InvalidFormat)
|
||||
);
|
||||
assert_eq!(
|
||||
parse_msg_pack_str(&[0xd9, 0]),
|
||||
Err(MsgPackStrParseError::InvalidFormat)
|
||||
);
|
||||
|
||||
// Make sure that parse doesn't fail looking for bytes after str 16 len
|
||||
assert_eq!(
|
||||
parse_msg_pack_str(&[0xda]),
|
||||
Err(MsgPackStrParseError::InvalidFormat)
|
||||
);
|
||||
assert_eq!(
|
||||
parse_msg_pack_str(&[0xda, 0]),
|
||||
Err(MsgPackStrParseError::InvalidFormat)
|
||||
);
|
||||
assert_eq!(
|
||||
parse_msg_pack_str(&[0xda, 0, 0]),
|
||||
Err(MsgPackStrParseError::InvalidFormat)
|
||||
);
|
||||
|
||||
// Make sure that parse doesn't fail looking for bytes after str 32 len
|
||||
assert_eq!(
|
||||
parse_msg_pack_str(&[0xdb]),
|
||||
Err(MsgPackStrParseError::InvalidFormat)
|
||||
);
|
||||
assert_eq!(
|
||||
parse_msg_pack_str(&[0xdb, 0]),
|
||||
Err(MsgPackStrParseError::InvalidFormat)
|
||||
);
|
||||
assert_eq!(
|
||||
parse_msg_pack_str(&[0xdb, 0, 0]),
|
||||
Err(MsgPackStrParseError::InvalidFormat)
|
||||
);
|
||||
assert_eq!(
|
||||
parse_msg_pack_str(&[0xdb, 0, 0, 0]),
|
||||
Err(MsgPackStrParseError::InvalidFormat)
|
||||
);
|
||||
assert_eq!(
|
||||
parse_msg_pack_str(&[0xdb, 0, 0, 0, 0]),
|
||||
Err(MsgPackStrParseError::InvalidFormat)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_fail_parsing_other_types() {
|
||||
assert_eq!(
|
||||
parse_msg_pack_str(&[0xc3]), // Boolean (true)
|
||||
Err(MsgPackStrParseError::InvalidFormat)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_fail_if_empty_input() {
|
||||
assert_eq!(
|
||||
parse_msg_pack_str(&[]),
|
||||
Err(MsgPackStrParseError::InvalidFormat)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_fail_if_str_is_not_utf8() {
|
||||
assert!(matches!(
|
||||
parse_msg_pack_str(&[0xa4, 0, 159, 146, 150]),
|
||||
Err(MsgPackStrParseError::Utf8Error(_))
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,629 @@
|
||||
use async_trait::async_trait;
|
||||
use std::{io, time::Duration};
|
||||
|
||||
mod framed;
|
||||
pub use framed::*;
|
||||
|
||||
mod inmemory;
|
||||
pub use inmemory::*;
|
||||
|
||||
mod tcp;
|
||||
pub use tcp::*;
|
||||
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
|
||||
#[cfg(test)]
|
||||
pub use test::*;
|
||||
|
||||
#[cfg(unix)]
|
||||
mod unix;
|
||||
|
||||
#[cfg(unix)]
|
||||
pub use unix::*;
|
||||
|
||||
#[cfg(windows)]
|
||||
mod windows;
|
||||
|
||||
#[cfg(windows)]
|
||||
pub use windows::*;
|
||||
|
||||
pub use tokio::io::{Interest, Ready};
|
||||
|
||||
/// Duration to wait after WouldBlock received during looping operations like `read_exact`.
|
||||
const SLEEP_DURATION: Duration = Duration::from_millis(1);
|
||||
|
||||
/// Interface representing a connection that is reconnectable.
|
||||
#[async_trait]
|
||||
pub trait Reconnectable {
|
||||
/// Attempts to reconnect an already-established connection.
|
||||
async fn reconnect(&mut self) -> io::Result<()>;
|
||||
}
|
||||
|
||||
/// Interface representing a transport of raw bytes into and out of the system.
|
||||
#[async_trait]
|
||||
pub trait Transport: Reconnectable + Send + Sync {
|
||||
/// Tries to read data from the transport into the provided buffer, returning how many bytes
|
||||
/// were read.
|
||||
///
|
||||
/// This call may return an error with [`ErrorKind::WouldBlock`] in the case that the transport
|
||||
/// is not ready to read data.
|
||||
///
|
||||
/// [`ErrorKind::WouldBlock`]: io::ErrorKind::WouldBlock
|
||||
fn try_read(&self, buf: &mut [u8]) -> io::Result<usize>;
|
||||
|
||||
/// Try to write a buffer to the transport, returning how many bytes were written.
|
||||
///
|
||||
/// This call may return an error with [`ErrorKind::WouldBlock`] in the case that the transport
|
||||
/// is not ready to write data.
|
||||
///
|
||||
/// [`ErrorKind::WouldBlock`]: io::ErrorKind::WouldBlock
|
||||
fn try_write(&self, buf: &[u8]) -> io::Result<usize>;
|
||||
|
||||
/// Waits for the transport to be ready based on the given interest, returning the ready
|
||||
/// status.
|
||||
async fn ready(&self, interest: Interest) -> io::Result<Ready>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Transport for Box<dyn Transport> {
|
||||
fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
Transport::try_read(AsRef::as_ref(self), buf)
|
||||
}
|
||||
|
||||
fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
|
||||
Transport::try_write(AsRef::as_ref(self), buf)
|
||||
}
|
||||
|
||||
async fn ready(&self, interest: Interest) -> io::Result<Ready> {
|
||||
Transport::ready(AsRef::as_ref(self), interest).await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Reconnectable for Box<dyn Transport> {
|
||||
async fn reconnect(&mut self) -> io::Result<()> {
|
||||
Reconnectable::reconnect(AsMut::as_mut(self)).await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait TransportExt {
|
||||
/// Waits for the transport to be readable to follow up with `try_read`.
|
||||
async fn readable(&self) -> io::Result<()>;
|
||||
|
||||
/// Waits for the transport to be writeable to follow up with `try_write`.
|
||||
async fn writeable(&self) -> io::Result<()>;
|
||||
|
||||
/// Waits for the transport to be either readable or writeable.
|
||||
async fn readable_or_writeable(&self) -> io::Result<()>;
|
||||
|
||||
/// Reads exactly `n` bytes where `n` is the length of `buf` by continuing to call [`try_read`]
|
||||
/// until completed. Calls to [`readable`] are made to ensure the transport is ready. Returns
|
||||
/// the total bytes read.
|
||||
///
|
||||
/// [`try_read`]: Transport::try_read
|
||||
/// [`readable`]: Transport::readable
|
||||
async fn read_exact(&self, buf: &mut [u8]) -> io::Result<usize>;
|
||||
|
||||
/// Reads all bytes until EOF in this source, placing them into `buf`.
|
||||
///
|
||||
/// All bytes read from this source will be appended to the specified buffer `buf`. This
|
||||
/// function will continuously call [`try_read`] to append more data to `buf` until
|
||||
/// [`try_read`] returns either [`Ok(0)`] or an error that is neither [`Interrupted`] or
|
||||
/// [`WouldBlock`].
|
||||
///
|
||||
/// If successful, this function will return the total number of bytes read.
|
||||
///
|
||||
/// ### Errors
|
||||
///
|
||||
/// If this function encounters an error of the kind [`Interrupted`] or [`WouldBlock`], then
|
||||
/// the error is ignored and the operation will continue.
|
||||
///
|
||||
/// If any other read error is encountered then this function immediately returns. Any bytes
|
||||
/// which have already been read will be appended to `buf`.
|
||||
///
|
||||
/// [`Ok(0)`]: Ok
|
||||
/// [`try_read`]: Transport::try_read
|
||||
/// [`readable`]: Transport::readable
|
||||
async fn read_to_end(&self, buf: &mut Vec<u8>) -> io::Result<usize>;
|
||||
|
||||
/// Reads all bytes until EOF in this source, placing them into `buf`.
|
||||
///
|
||||
/// If successful, this function will return the total number of bytes read.
|
||||
///
|
||||
/// ### Errors
|
||||
///
|
||||
/// If the data in this stream is *not* valid UTF-8 then an error is returned and `buf` is
|
||||
/// unchanged.
|
||||
///
|
||||
/// See [`read_to_end`] for other error semantics.
|
||||
///
|
||||
/// [`Ok(0)`]: Ok
|
||||
/// [`try_read`]: Transport::try_read
|
||||
/// [`readable`]: Transport::readable
|
||||
/// [`read_to_end`]: TransportExt::read_to_end
|
||||
async fn read_to_string(&self, buf: &mut String) -> io::Result<usize>;
|
||||
|
||||
/// Writes all of `buf` by continuing to call [`try_write`] until completed. Calls to
|
||||
/// [`writeable`] are made to ensure the transport is ready.
|
||||
///
|
||||
/// [`try_write`]: Transport::try_write
|
||||
/// [`writable`]: Transport::writable
|
||||
async fn write_all(&self, buf: &[u8]) -> io::Result<()>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T: Transport> TransportExt for T {
|
||||
async fn readable(&self) -> io::Result<()> {
|
||||
self.ready(Interest::READABLE).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn writeable(&self) -> io::Result<()> {
|
||||
self.ready(Interest::WRITABLE).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn readable_or_writeable(&self) -> io::Result<()> {
|
||||
self.ready(Interest::READABLE | Interest::WRITABLE).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn read_exact(&self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
let mut i = 0;
|
||||
|
||||
while i < buf.len() {
|
||||
self.readable().await?;
|
||||
|
||||
match self.try_read(&mut buf[i..]) {
|
||||
// If we get 0 bytes read, this usually means that the underlying reader
|
||||
// has closed, so we will return an EOF error to reflect that
|
||||
//
|
||||
// NOTE: `try_read` can also return 0 if the buf len is zero, but because we check
|
||||
// that our index is < len, the situation where we call try_read with a buf
|
||||
// of len 0 will never happen
|
||||
Ok(0) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)),
|
||||
|
||||
Ok(n) => i += n,
|
||||
|
||||
// Because we are using `try_read`, it can be possible for it to return
|
||||
// WouldBlock; so, if we encounter that then we just wait for next readable
|
||||
Err(x) if x.kind() == io::ErrorKind::WouldBlock => {
|
||||
// NOTE: We sleep for a little bit before trying again to avoid pegging CPU
|
||||
tokio::time::sleep(SLEEP_DURATION).await
|
||||
}
|
||||
|
||||
Err(x) => return Err(x),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(i)
|
||||
}
|
||||
|
||||
async fn read_to_end(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
|
||||
let mut i = 0;
|
||||
let mut tmp = [0u8; 1024];
|
||||
|
||||
loop {
|
||||
self.readable().await?;
|
||||
|
||||
match self.try_read(&mut tmp) {
|
||||
Ok(0) => return Ok(i),
|
||||
Ok(n) => {
|
||||
buf.extend_from_slice(&tmp[..n]);
|
||||
i += n;
|
||||
}
|
||||
Err(x)
|
||||
if x.kind() == io::ErrorKind::WouldBlock
|
||||
|| x.kind() == io::ErrorKind::Interrupted =>
|
||||
{
|
||||
// NOTE: We sleep for a little bit before trying again to avoid pegging CPU
|
||||
tokio::time::sleep(SLEEP_DURATION).await
|
||||
}
|
||||
|
||||
Err(x) => return Err(x),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn read_to_string(&self, buf: &mut String) -> io::Result<usize> {
|
||||
let mut tmp = Vec::new();
|
||||
let n = self.read_to_end(&mut tmp).await?;
|
||||
buf.push_str(
|
||||
&String::from_utf8(tmp).map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))?,
|
||||
);
|
||||
Ok(n)
|
||||
}
|
||||
|
||||
async fn write_all(&self, buf: &[u8]) -> io::Result<()> {
|
||||
let mut i = 0;
|
||||
|
||||
while i < buf.len() {
|
||||
self.writeable().await?;
|
||||
|
||||
match self.try_write(&buf[i..]) {
|
||||
// If we get 0 bytes written, this usually means that the underlying writer
|
||||
// has closed, so we will return a write zero error to reflect that
|
||||
//
|
||||
// NOTE: `try_write` can also return 0 if the buf len is zero, but because we check
|
||||
// that our index is < len, the situation where we call try_write with a buf
|
||||
// of len 0 will never happen
|
||||
Ok(0) => return Err(io::Error::from(io::ErrorKind::WriteZero)),
|
||||
|
||||
Ok(n) => i += n,
|
||||
|
||||
// Because we are using `try_write`, it can be possible for it to return
|
||||
// WouldBlock; so, if we encounter that then we just wait for next writeable
|
||||
Err(x) if x.kind() == io::ErrorKind::WouldBlock => {
|
||||
// NOTE: We sleep for a little bit before trying again to avoid pegging CPU
|
||||
tokio::time::sleep(SLEEP_DURATION).await
|
||||
}
|
||||
|
||||
Err(x) => return Err(x),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use test_log::test;
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn read_exact_should_fail_if_try_read_encounters_error_other_than_would_block() {
|
||||
let transport = TestTransport {
|
||||
f_try_read: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
|
||||
f_ready: Box::new(|_| Ok(Ready::READABLE)),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut buf = [0; 1];
|
||||
assert_eq!(
|
||||
transport.read_exact(&mut buf).await.unwrap_err().kind(),
|
||||
io::ErrorKind::NotConnected
|
||||
);
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn read_exact_should_fail_if_try_read_returns_0_before_necessary_bytes_read() {
|
||||
let transport = TestTransport {
|
||||
f_try_read: Box::new(|_| Ok(0)),
|
||||
f_ready: Box::new(|_| Ok(Ready::READABLE)),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut buf = [0; 1];
|
||||
assert_eq!(
|
||||
transport.read_exact(&mut buf).await.unwrap_err().kind(),
|
||||
io::ErrorKind::UnexpectedEof
|
||||
);
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn read_exact_should_continue_to_call_try_read_until_buffer_is_filled() {
|
||||
let transport = TestTransport {
|
||||
f_try_read: Box::new(|buf| {
|
||||
static mut CNT: u8 = 0;
|
||||
unsafe {
|
||||
buf[0] = b'a' + CNT;
|
||||
CNT += 1;
|
||||
}
|
||||
Ok(1)
|
||||
}),
|
||||
f_ready: Box::new(|_| Ok(Ready::READABLE)),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut buf = [0; 3];
|
||||
assert_eq!(transport.read_exact(&mut buf).await.unwrap(), 3);
|
||||
assert_eq!(&buf, b"abc");
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn read_exact_should_continue_to_call_try_read_while_it_returns_would_block() {
|
||||
// Configure `try_read` to alternate between reading a byte and WouldBlock
|
||||
let transport = TestTransport {
|
||||
f_try_read: Box::new(|buf| {
|
||||
static mut CNT: u8 = 0;
|
||||
unsafe {
|
||||
buf[0] = b'a' + CNT;
|
||||
CNT += 1;
|
||||
if CNT % 2 == 1 {
|
||||
Ok(1)
|
||||
} else {
|
||||
Err(io::Error::from(io::ErrorKind::WouldBlock))
|
||||
}
|
||||
}
|
||||
}),
|
||||
f_ready: Box::new(|_| Ok(Ready::READABLE)),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut buf = [0; 3];
|
||||
assert_eq!(transport.read_exact(&mut buf).await.unwrap(), 3);
|
||||
assert_eq!(&buf, b"ace");
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn read_exact_should_return_0_if_given_a_buffer_of_0_len() {
|
||||
let transport = TestTransport {
|
||||
f_try_read: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
|
||||
f_ready: Box::new(|_| Ok(Ready::READABLE)),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut buf = [0; 0];
|
||||
assert_eq!(transport.read_exact(&mut buf).await.unwrap(), 0);
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn read_to_end_should_fail_if_try_read_encounters_error_other_than_would_block_and_interrupt(
|
||||
) {
|
||||
let transport = TestTransport {
|
||||
f_try_read: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
|
||||
f_ready: Box::new(|_| Ok(Ready::READABLE)),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
transport
|
||||
.read_to_end(&mut Vec::new())
|
||||
.await
|
||||
.unwrap_err()
|
||||
.kind(),
|
||||
io::ErrorKind::NotConnected
|
||||
);
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn read_to_end_should_read_until_0_bytes_returned_from_try_read() {
|
||||
let transport = TestTransport {
|
||||
f_try_read: Box::new(|buf| {
|
||||
static mut CNT: u8 = 0;
|
||||
unsafe {
|
||||
if CNT == 0 {
|
||||
buf[..5].copy_from_slice(b"hello");
|
||||
CNT += 1;
|
||||
Ok(5)
|
||||
} else {
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
}),
|
||||
f_ready: Box::new(|_| Ok(Ready::READABLE)),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut buf = Vec::new();
|
||||
assert_eq!(transport.read_to_end(&mut buf).await.unwrap(), 5);
|
||||
assert_eq!(buf, b"hello");
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn read_to_end_should_continue_reading_when_interrupt_or_would_block_encountered() {
|
||||
let transport = TestTransport {
|
||||
f_try_read: Box::new(|buf| {
|
||||
static mut CNT: u8 = 0;
|
||||
unsafe {
|
||||
CNT += 1;
|
||||
if CNT == 1 {
|
||||
buf[..6].copy_from_slice(b"hello ");
|
||||
Ok(6)
|
||||
} else if CNT == 2 {
|
||||
Err(io::Error::from(io::ErrorKind::WouldBlock))
|
||||
} else if CNT == 3 {
|
||||
buf[..5].copy_from_slice(b"world");
|
||||
Ok(5)
|
||||
} else if CNT == 4 {
|
||||
Err(io::Error::from(io::ErrorKind::Interrupted))
|
||||
} else if CNT == 5 {
|
||||
buf[..6].copy_from_slice(b", test");
|
||||
Ok(6)
|
||||
} else {
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
}),
|
||||
f_ready: Box::new(|_| Ok(Ready::READABLE)),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut buf = Vec::new();
|
||||
assert_eq!(transport.read_to_end(&mut buf).await.unwrap(), 17);
|
||||
assert_eq!(buf, b"hello world, test");
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn read_to_string_should_fail_if_try_read_encounters_error_other_than_would_block_and_interrupt(
|
||||
) {
|
||||
let transport = TestTransport {
|
||||
f_try_read: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
|
||||
f_ready: Box::new(|_| Ok(Ready::READABLE)),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
transport
|
||||
.read_to_string(&mut String::new())
|
||||
.await
|
||||
.unwrap_err()
|
||||
.kind(),
|
||||
io::ErrorKind::NotConnected
|
||||
);
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn read_to_string_should_fail_if_non_utf8_characters_read() {
|
||||
let transport = TestTransport {
|
||||
f_try_read: Box::new(|buf| {
|
||||
static mut CNT: u8 = 0;
|
||||
unsafe {
|
||||
if CNT == 0 {
|
||||
buf[0] = 0;
|
||||
buf[1] = 159;
|
||||
buf[2] = 146;
|
||||
buf[3] = 150;
|
||||
CNT += 1;
|
||||
Ok(4)
|
||||
} else {
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
}),
|
||||
f_ready: Box::new(|_| Ok(Ready::READABLE)),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut buf = String::new();
|
||||
assert_eq!(
|
||||
transport.read_to_string(&mut buf).await.unwrap_err().kind(),
|
||||
io::ErrorKind::InvalidData
|
||||
);
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn read_to_string_should_read_until_0_bytes_returned_from_try_read() {
|
||||
let transport = TestTransport {
|
||||
f_try_read: Box::new(|buf| {
|
||||
static mut CNT: u8 = 0;
|
||||
unsafe {
|
||||
if CNT == 0 {
|
||||
buf[..5].copy_from_slice(b"hello");
|
||||
CNT += 1;
|
||||
Ok(5)
|
||||
} else {
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
}),
|
||||
f_ready: Box::new(|_| Ok(Ready::READABLE)),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut buf = String::new();
|
||||
assert_eq!(transport.read_to_string(&mut buf).await.unwrap(), 5);
|
||||
assert_eq!(buf, "hello");
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn read_to_string_should_continue_reading_when_interrupt_or_would_block_encountered() {
|
||||
let transport = TestTransport {
|
||||
f_try_read: Box::new(|buf| {
|
||||
static mut CNT: u8 = 0;
|
||||
unsafe {
|
||||
CNT += 1;
|
||||
if CNT == 1 {
|
||||
buf[..6].copy_from_slice(b"hello ");
|
||||
Ok(6)
|
||||
} else if CNT == 2 {
|
||||
Err(io::Error::from(io::ErrorKind::WouldBlock))
|
||||
} else if CNT == 3 {
|
||||
buf[..5].copy_from_slice(b"world");
|
||||
Ok(5)
|
||||
} else if CNT == 4 {
|
||||
Err(io::Error::from(io::ErrorKind::Interrupted))
|
||||
} else if CNT == 5 {
|
||||
buf[..6].copy_from_slice(b", test");
|
||||
Ok(6)
|
||||
} else {
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
}),
|
||||
f_ready: Box::new(|_| Ok(Ready::READABLE)),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut buf = String::new();
|
||||
assert_eq!(transport.read_to_string(&mut buf).await.unwrap(), 17);
|
||||
assert_eq!(buf, "hello world, test");
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn write_all_should_fail_if_try_write_encounters_error_other_than_would_block() {
|
||||
let transport = TestTransport {
|
||||
f_try_write: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
|
||||
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
transport.write_all(b"abc").await.unwrap_err().kind(),
|
||||
io::ErrorKind::NotConnected
|
||||
);
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn write_all_should_fail_if_try_write_returns_0_before_all_bytes_written() {
|
||||
let transport = TestTransport {
|
||||
f_try_write: Box::new(|_| Ok(0)),
|
||||
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
transport.write_all(b"abc").await.unwrap_err().kind(),
|
||||
io::ErrorKind::WriteZero
|
||||
);
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn write_all_should_continue_to_call_try_write_until_all_bytes_written() {
|
||||
// Configure `try_write` to alternate between writing a byte and WouldBlock
|
||||
let transport = TestTransport {
|
||||
f_try_write: Box::new(|buf| {
|
||||
static mut CNT: u8 = 0;
|
||||
unsafe {
|
||||
assert_eq!(buf[0], b'a' + CNT);
|
||||
CNT += 1;
|
||||
Ok(1)
|
||||
}
|
||||
}),
|
||||
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
transport.write_all(b"abc").await.unwrap();
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn write_all_should_continue_to_call_try_write_while_it_returns_would_block() {
|
||||
// Configure `try_write` to alternate between writing a byte and WouldBlock
|
||||
let transport = TestTransport {
|
||||
f_try_write: Box::new(|buf| {
|
||||
static mut CNT: u8 = 0;
|
||||
unsafe {
|
||||
if CNT % 2 == 0 {
|
||||
assert_eq!(buf[0], b'a' + CNT);
|
||||
CNT += 1;
|
||||
Ok(1)
|
||||
} else {
|
||||
CNT += 1;
|
||||
Err(io::Error::from(io::ErrorKind::WouldBlock))
|
||||
}
|
||||
}
|
||||
}),
|
||||
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
transport.write_all(b"ace").await.unwrap();
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn write_all_should_return_immediately_if_given_buffer_of_0_len() {
|
||||
let transport = TestTransport {
|
||||
f_try_write: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
|
||||
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// No error takes place as we never call try_write
|
||||
let buf = [0; 0];
|
||||
transport.write_all(&buf).await.unwrap();
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,201 @@
|
||||
use super::{Frame, OwnedFrame};
|
||||
use std::collections::VecDeque;
|
||||
|
||||
/// Maximum size (in bytes) for saved frames (256MiB)
|
||||
const MAX_BACKUP_SIZE: usize = 256 * 1024 * 1024;
|
||||
|
||||
/// Stores [`Frame`]s for reuse later.
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct Backup {
|
||||
/// Maximum size (in bytes) to save frames in case we need to backup them
|
||||
///
|
||||
/// NOTE: If 0, no frames will be stored.
|
||||
max_backup_size: usize,
|
||||
|
||||
/// Tracker for the total size (in bytes) of stored frames
|
||||
current_backup_size: usize,
|
||||
|
||||
/// Storage used to hold outgoing frames in case they need to be reused
|
||||
frames: VecDeque<OwnedFrame>,
|
||||
|
||||
/// Counter keeping track of total frames sent
|
||||
sent_cnt: u64,
|
||||
|
||||
/// Counter keeping track of total frames received
|
||||
received_cnt: u64,
|
||||
|
||||
/// Indicates whether the backup is frozen, which indicates that mutations are ignored
|
||||
frozen: bool,
|
||||
}
|
||||
|
||||
impl Default for Backup {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Backup {
|
||||
/// Creates a new, unfrozen backup.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
max_backup_size: MAX_BACKUP_SIZE,
|
||||
current_backup_size: 0,
|
||||
frames: VecDeque::new(),
|
||||
sent_cnt: 0,
|
||||
received_cnt: 0,
|
||||
frozen: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Clears the backup of any stored data and resets the state to being new.
|
||||
///
|
||||
/// ### Note
|
||||
///
|
||||
/// Like all other modifications, this will do nothing if the backup is frozen.
|
||||
pub fn clear(&mut self) {
|
||||
if !self.frozen {
|
||||
self.current_backup_size = 0;
|
||||
self.frames.clear();
|
||||
self.sent_cnt = 0;
|
||||
self.received_cnt = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if the backup is frozen, meaning that modifications will be ignored.
|
||||
#[inline]
|
||||
pub fn is_frozen(&self) -> bool {
|
||||
self.frozen
|
||||
}
|
||||
|
||||
/// Sets the frozen status.
|
||||
#[inline]
|
||||
pub fn set_frozen(&mut self, frozen: bool) {
|
||||
self.frozen = frozen;
|
||||
}
|
||||
|
||||
/// Marks the backup as frozen.
|
||||
#[inline]
|
||||
pub fn freeze(&mut self) {
|
||||
self.frozen = true;
|
||||
}
|
||||
|
||||
/// Marks the backup as no longer frozen.
|
||||
#[inline]
|
||||
pub fn unfreeze(&mut self) {
|
||||
self.frozen = false;
|
||||
}
|
||||
|
||||
/// Sets the maximum size (in bytes) of collective frames stored in case a backup is needed
|
||||
/// during reconnection. Setting the `size` to 0 will result in no frames being stored.
|
||||
///
|
||||
/// ### Note
|
||||
///
|
||||
/// Like all other modifications, this will do nothing if the backup is frozen.
|
||||
pub fn set_max_backup_size(&mut self, size: usize) {
|
||||
if !self.frozen {
|
||||
self.max_backup_size = size;
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the maximum size (in bytes) of collective frames stored in case a backup is needed
|
||||
/// during reconnection.
|
||||
pub fn max_backup_size(&self) -> usize {
|
||||
self.max_backup_size
|
||||
}
|
||||
|
||||
/// Increments (by 1) the total sent frames.
|
||||
///
|
||||
/// ### Note
|
||||
///
|
||||
/// Like all other modifications, this will do nothing if the backup is frozen.
|
||||
pub(crate) fn increment_sent_cnt(&mut self) {
|
||||
if !self.frozen {
|
||||
self.sent_cnt += 1;
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns how many frames have been sent.
|
||||
pub(crate) fn sent_cnt(&self) -> u64 {
|
||||
self.sent_cnt
|
||||
}
|
||||
|
||||
/// Increments (by 1) the total received frames.
|
||||
///
|
||||
/// ### Note
|
||||
///
|
||||
/// Like all other modifications, this will do nothing if the backup is frozen.
|
||||
pub(super) fn increment_received_cnt(&mut self) {
|
||||
if !self.frozen {
|
||||
self.received_cnt += 1;
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns how many frames have been received.
|
||||
pub(crate) fn received_cnt(&self) -> u64 {
|
||||
self.received_cnt
|
||||
}
|
||||
|
||||
/// Sets the total received frames to the specified `cnt`.
|
||||
///
|
||||
/// ### Note
|
||||
///
|
||||
/// Like all other modifications, this will do nothing if the backup is frozen.
|
||||
pub(super) fn set_received_cnt(&mut self, cnt: u64) {
|
||||
if !self.frozen {
|
||||
self.received_cnt = cnt;
|
||||
}
|
||||
}
|
||||
|
||||
/// Pushes a new frame to the end of the internal queue.
|
||||
///
|
||||
/// ### Note
|
||||
///
|
||||
/// Like all other modifications, this will do nothing if the backup is frozen.
|
||||
pub(crate) fn push_frame(&mut self, frame: Frame) {
|
||||
if self.max_backup_size > 0 && !self.frozen {
|
||||
self.current_backup_size += frame.len();
|
||||
self.frames.push_back(frame.into_owned());
|
||||
while self.current_backup_size > self.max_backup_size {
|
||||
match self.frames.pop_front() {
|
||||
Some(frame) => {
|
||||
self.current_backup_size -= frame.len();
|
||||
}
|
||||
|
||||
// If we have exhausted all frames, then we have reached
|
||||
// an internal size of 0 and should exit the loop
|
||||
None => {
|
||||
self.current_backup_size = 0;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the total frames being kept for potential reuse.
|
||||
pub(super) fn frame_cnt(&self) -> usize {
|
||||
self.frames.len()
|
||||
}
|
||||
|
||||
/// Returns an iterator over the frames contained in the backup.
|
||||
pub(super) fn frames(&self) -> impl Iterator<Item = &Frame> {
|
||||
self.frames.iter()
|
||||
}
|
||||
|
||||
/// Truncates the stored frames to be no larger than `size` total frames by popping from the
|
||||
/// front rather than the back of the list.
|
||||
///
|
||||
/// ### Note
|
||||
///
|
||||
/// Like all other modifications, this will do nothing if the backup is frozen.
|
||||
pub(super) fn truncate_front(&mut self, size: usize) {
|
||||
if !self.frozen {
|
||||
while self.frames.len() > size {
|
||||
if let Some(frame) = self.frames.pop_front() {
|
||||
self.current_backup_size -=
|
||||
std::cmp::min(frame.len(), self.current_backup_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue