mirror of
https://github.com/chipsenkbeil/distant.git
synced 2024-11-05 12:00:36 +00:00
Add tests for codec and transport; move net::client to dedicated file
This commit is contained in:
parent
f6e9195503
commit
e857dabe43
227
src/core/net/client.rs
Normal file
227
src/core/net/client.rs
Normal file
@ -0,0 +1,227 @@
|
||||
use crate::core::{
|
||||
constants::CLIENT_BROADCAST_CHANNEL_CAPACITY,
|
||||
data::{Request, Response},
|
||||
net::{DataStream, Transport, TransportError, TransportWriteHalf},
|
||||
session::Session,
|
||||
utils,
|
||||
};
|
||||
use log::*;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
convert,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
use tokio::{
|
||||
io,
|
||||
net::TcpStream,
|
||||
sync::{broadcast, oneshot},
|
||||
task::{JoinError, JoinHandle},
|
||||
time::Duration,
|
||||
};
|
||||
use tokio_stream::wrappers::BroadcastStream;
|
||||
|
||||
type Callbacks = Arc<Mutex<HashMap<usize, oneshot::Sender<Response>>>>;
|
||||
|
||||
/// Represents a client that can make requests against a server
|
||||
pub struct Client<T>
|
||||
where
|
||||
T: DataStream,
|
||||
{
|
||||
/// Underlying transport used by client
|
||||
t_write: TransportWriteHalf<T::Write>,
|
||||
|
||||
/// Collection of callbacks to be invoked upon receiving a response to a request
|
||||
callbacks: Callbacks,
|
||||
|
||||
/// Callback to trigger when a response is received without an origin or with an origin
|
||||
/// not found in the list of callbacks
|
||||
broadcast: broadcast::Sender<Response>,
|
||||
|
||||
/// Represents an initial receiver for broadcasted responses that can capture responses
|
||||
/// prior to a stream being established and consumed
|
||||
init_broadcast_receiver: Option<broadcast::Receiver<Response>>,
|
||||
|
||||
/// Contains the task that is running to receive responses from a server
|
||||
response_task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl Client<TcpStream> {
|
||||
/// Connect to a remote TCP session
|
||||
pub async fn tcp_connect(session: Session) -> io::Result<Self> {
|
||||
let transport = Transport::<TcpStream>::connect(session).await?;
|
||||
debug!(
|
||||
"Client has connected to {}",
|
||||
transport
|
||||
.peer_addr()
|
||||
.map(|x| x.to_string())
|
||||
.unwrap_or_else(|_| String::from("???"))
|
||||
);
|
||||
Self::inner_connect(transport).await
|
||||
}
|
||||
|
||||
/// Connect to a remote TCP session, timing out after duration has passed
|
||||
pub async fn tcp_connect_timeout(session: Session, duration: Duration) -> io::Result<Self> {
|
||||
utils::timeout(duration, Self::tcp_connect(session))
|
||||
.await
|
||||
.and_then(convert::identity)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
impl Client<tokio::net::UnixStream> {
|
||||
/// Connect to a proxy unix socket
|
||||
pub async fn unix_connect(
|
||||
path: impl AsRef<std::path::Path>,
|
||||
auth_key: Option<Arc<orion::aead::SecretKey>>,
|
||||
) -> io::Result<Self> {
|
||||
let transport = Transport::<tokio::net::UnixStream>::connect(path, auth_key).await?;
|
||||
debug!(
|
||||
"Client has connected to {}",
|
||||
transport
|
||||
.peer_addr()
|
||||
.map(|x| format!("{:?}", x))
|
||||
.unwrap_or_else(|_| String::from("???"))
|
||||
);
|
||||
Self::inner_connect(transport).await
|
||||
}
|
||||
|
||||
/// Connect to a proxy unix socket, timing out after duration has passed
|
||||
pub async fn unix_connect_timeout(
|
||||
path: impl AsRef<std::path::Path>,
|
||||
auth_key: Option<Arc<orion::aead::SecretKey>>,
|
||||
duration: Duration,
|
||||
) -> io::Result<Self> {
|
||||
utils::timeout(duration, Self::unix_connect(path, auth_key))
|
||||
.await
|
||||
.and_then(convert::identity)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Client<T>
|
||||
where
|
||||
T: DataStream,
|
||||
{
|
||||
/// Establishes a connection using the provided session
|
||||
async fn inner_connect(transport: Transport<T>) -> io::Result<Self> {
|
||||
let (mut t_read, t_write) = transport.into_split();
|
||||
let callbacks: Callbacks = Arc::new(Mutex::new(HashMap::new()));
|
||||
let (broadcast, init_broadcast_receiver) =
|
||||
broadcast::channel(CLIENT_BROADCAST_CHANNEL_CAPACITY);
|
||||
|
||||
// Start a task that continually checks for responses and triggers callbacks
|
||||
let callbacks_2 = Arc::clone(&callbacks);
|
||||
let broadcast_2 = broadcast.clone();
|
||||
let response_task = tokio::spawn(async move {
|
||||
loop {
|
||||
match t_read.receive::<Response>().await {
|
||||
Ok(Some(res)) => {
|
||||
trace!("Client got response: {:?}", res);
|
||||
let maybe_callback = res
|
||||
.origin_id
|
||||
.as_ref()
|
||||
.and_then(|id| callbacks_2.lock().unwrap().remove(id));
|
||||
|
||||
// If there is an origin to this response, trigger the callback
|
||||
if let Some(tx) = maybe_callback {
|
||||
trace!("Client has callback! Triggering!");
|
||||
if let Err(res) = tx.send(res) {
|
||||
error!("Failed to trigger callback for response {}", res.id);
|
||||
}
|
||||
|
||||
// Otherwise, this goes into the junk draw of response handlers
|
||||
} else {
|
||||
trace!("Client does not have callback! Broadcasting!");
|
||||
if let Err(x) = broadcast_2.send(res) {
|
||||
error!("Failed to trigger broadcast: {}", x);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(None) => break,
|
||||
Err(x) => {
|
||||
error!("{}", x);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Self {
|
||||
t_write,
|
||||
callbacks,
|
||||
broadcast,
|
||||
init_broadcast_receiver: Some(init_broadcast_receiver),
|
||||
response_task,
|
||||
})
|
||||
}
|
||||
|
||||
/// Waits for the client to terminate, which results when the receiving end of the network
|
||||
/// connection is closed (or the client is shutdown)
|
||||
pub async fn wait(self) -> Result<(), JoinError> {
|
||||
self.response_task.await
|
||||
}
|
||||
|
||||
/// Abort the client's current connection by forcing its response task to shutdown
|
||||
pub fn abort(&self) {
|
||||
self.response_task.abort()
|
||||
}
|
||||
|
||||
/// Sends a request and waits for a response
|
||||
pub async fn send(&mut self, req: Request) -> Result<Response, TransportError> {
|
||||
// First, add a callback that will trigger when we get the response for this request
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.callbacks.lock().unwrap().insert(req.id, tx);
|
||||
|
||||
// Second, send the request
|
||||
self.t_write.send(req).await?;
|
||||
|
||||
// Third, wait for the response
|
||||
rx.await
|
||||
.map_err(|x| TransportError::from(io::Error::new(io::ErrorKind::ConnectionAborted, x)))
|
||||
}
|
||||
|
||||
/// Sends a request and waits for a response, timing out after duration has passed
|
||||
pub async fn send_timeout(
|
||||
&mut self,
|
||||
req: Request,
|
||||
duration: Duration,
|
||||
) -> Result<Response, TransportError> {
|
||||
utils::timeout(duration, self.send(req))
|
||||
.await
|
||||
.map_err(TransportError::from)
|
||||
.and_then(convert::identity)
|
||||
}
|
||||
|
||||
/// Sends a request without waiting for a response
|
||||
///
|
||||
/// Any response that would be received gets sent over the broadcast channel instead
|
||||
pub async fn fire(&mut self, req: Request) -> Result<(), TransportError> {
|
||||
self.t_write.send(req).await
|
||||
}
|
||||
|
||||
/// Sends a request without waiting for a response, timing out after duration has passed
|
||||
pub async fn fire_timeout(
|
||||
&mut self,
|
||||
req: Request,
|
||||
duration: Duration,
|
||||
) -> Result<(), TransportError> {
|
||||
utils::timeout(duration, self.fire(req))
|
||||
.await
|
||||
.map_err(TransportError::from)
|
||||
.and_then(convert::identity)
|
||||
}
|
||||
|
||||
/// Clones a new instance of the broadcaster used by the client
|
||||
pub fn to_response_broadcaster(&self) -> broadcast::Sender<Response> {
|
||||
self.broadcast.clone()
|
||||
}
|
||||
|
||||
/// Creates and returns a new stream of responses that are received that do not match the
|
||||
/// response to a `send` request
|
||||
pub fn to_response_broadcast_stream(&mut self) -> BroadcastStream<Response> {
|
||||
BroadcastStream::new(
|
||||
self.init_broadcast_receiver
|
||||
.take()
|
||||
.unwrap_or_else(|| self.broadcast.subscribe()),
|
||||
)
|
||||
}
|
||||
}
|
@ -1,213 +1,5 @@
|
||||
mod transport;
|
||||
pub use transport::{DataStream, Transport, TransportError, TransportReadHalf, TransportWriteHalf};
|
||||
|
||||
use crate::core::{
|
||||
constants::CLIENT_BROADCAST_CHANNEL_CAPACITY,
|
||||
data::{Request, Response},
|
||||
session::Session,
|
||||
utils,
|
||||
};
|
||||
use log::*;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
convert,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
use tokio::{
|
||||
io,
|
||||
net::TcpStream,
|
||||
sync::{broadcast, oneshot},
|
||||
time::Duration,
|
||||
};
|
||||
use tokio_stream::wrappers::BroadcastStream;
|
||||
|
||||
type Callbacks = Arc<Mutex<HashMap<usize, oneshot::Sender<Response>>>>;
|
||||
|
||||
/// Represents a client that can make requests against a server
|
||||
pub struct Client<T>
|
||||
where
|
||||
T: DataStream,
|
||||
{
|
||||
/// Underlying transport used by client
|
||||
t_write: TransportWriteHalf<T::Write>,
|
||||
|
||||
/// Collection of callbacks to be invoked upon receiving a response to a request
|
||||
callbacks: Callbacks,
|
||||
|
||||
/// Callback to trigger when a response is received without an origin or with an origin
|
||||
/// not found in the list of callbacks
|
||||
broadcast: broadcast::Sender<Response>,
|
||||
|
||||
/// Represents an initial receiver for broadcasted responses that can capture responses
|
||||
/// prior to a stream being established and consumed
|
||||
init_broadcast_receiver: Option<broadcast::Receiver<Response>>,
|
||||
}
|
||||
|
||||
impl Client<TcpStream> {
|
||||
/// Connect to a remote TCP session
|
||||
pub async fn tcp_connect(session: Session) -> io::Result<Self> {
|
||||
let transport = Transport::<TcpStream>::connect(session).await?;
|
||||
debug!(
|
||||
"Client has connected to {}",
|
||||
transport
|
||||
.peer_addr()
|
||||
.map(|x| x.to_string())
|
||||
.unwrap_or_else(|_| String::from("???"))
|
||||
);
|
||||
Self::inner_connect(transport).await
|
||||
}
|
||||
|
||||
/// Connect to a remote TCP session, timing out after duration has passed
|
||||
pub async fn tcp_connect_timeout(session: Session, duration: Duration) -> io::Result<Self> {
|
||||
utils::timeout(duration, Self::tcp_connect(session))
|
||||
.await
|
||||
.and_then(convert::identity)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
impl Client<tokio::net::UnixStream> {
|
||||
/// Connect to a proxy unix socket
|
||||
pub async fn unix_connect(
|
||||
path: impl AsRef<std::path::Path>,
|
||||
auth_key: Option<Arc<orion::aead::SecretKey>>,
|
||||
) -> io::Result<Self> {
|
||||
let transport = Transport::<tokio::net::UnixStream>::connect(path, auth_key).await?;
|
||||
debug!(
|
||||
"Client has connected to {}",
|
||||
transport
|
||||
.peer_addr()
|
||||
.map(|x| format!("{:?}", x))
|
||||
.unwrap_or_else(|_| String::from("???"))
|
||||
);
|
||||
Self::inner_connect(transport).await
|
||||
}
|
||||
|
||||
/// Connect to a proxy unix socket, timing out after duration has passed
|
||||
pub async fn unix_connect_timeout(
|
||||
path: impl AsRef<std::path::Path>,
|
||||
auth_key: Option<Arc<orion::aead::SecretKey>>,
|
||||
duration: Duration,
|
||||
) -> io::Result<Self> {
|
||||
utils::timeout(duration, Self::unix_connect(path, auth_key))
|
||||
.await
|
||||
.and_then(convert::identity)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Client<T>
|
||||
where
|
||||
T: DataStream,
|
||||
{
|
||||
/// Establishes a connection using the provided session
|
||||
async fn inner_connect(transport: Transport<T>) -> io::Result<Self> {
|
||||
let (mut t_read, t_write) = transport.into_split();
|
||||
let callbacks: Callbacks = Arc::new(Mutex::new(HashMap::new()));
|
||||
let (broadcast, init_broadcast_receiver) =
|
||||
broadcast::channel(CLIENT_BROADCAST_CHANNEL_CAPACITY);
|
||||
|
||||
// Start a task that continually checks for responses and triggers callbacks
|
||||
let callbacks_2 = Arc::clone(&callbacks);
|
||||
let broadcast_2 = broadcast.clone();
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
match t_read.receive::<Response>().await {
|
||||
Ok(Some(res)) => {
|
||||
trace!("Client got response: {:?}", res);
|
||||
let maybe_callback = res
|
||||
.origin_id
|
||||
.as_ref()
|
||||
.and_then(|id| callbacks_2.lock().unwrap().remove(id));
|
||||
|
||||
// If there is an origin to this response, trigger the callback
|
||||
if let Some(tx) = maybe_callback {
|
||||
trace!("Client has callback! Triggering!");
|
||||
if let Err(res) = tx.send(res) {
|
||||
error!("Failed to trigger callback for response {}", res.id);
|
||||
}
|
||||
|
||||
// Otherwise, this goes into the junk draw of response handlers
|
||||
} else {
|
||||
trace!("Client does not have callback! Broadcasting!");
|
||||
if let Err(x) = broadcast_2.send(res) {
|
||||
error!("Failed to trigger broadcast: {}", x);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(None) => break,
|
||||
Err(x) => {
|
||||
error!("{}", x);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Self {
|
||||
t_write,
|
||||
callbacks,
|
||||
broadcast,
|
||||
init_broadcast_receiver: Some(init_broadcast_receiver),
|
||||
})
|
||||
}
|
||||
|
||||
/// Sends a request and waits for a response
|
||||
pub async fn send(&mut self, req: Request) -> Result<Response, TransportError> {
|
||||
// First, add a callback that will trigger when we get the response for this request
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.callbacks.lock().unwrap().insert(req.id, tx);
|
||||
|
||||
// Second, send the request
|
||||
self.t_write.send(req).await?;
|
||||
|
||||
// Third, wait for the response
|
||||
rx.await
|
||||
.map_err(|x| TransportError::from(io::Error::new(io::ErrorKind::ConnectionAborted, x)))
|
||||
}
|
||||
|
||||
/// Sends a request and waits for a response, timing out after duration has passed
|
||||
pub async fn send_timeout(
|
||||
&mut self,
|
||||
req: Request,
|
||||
duration: Duration,
|
||||
) -> Result<Response, TransportError> {
|
||||
utils::timeout(duration, self.send(req))
|
||||
.await
|
||||
.map_err(TransportError::from)
|
||||
.and_then(convert::identity)
|
||||
}
|
||||
|
||||
/// Sends a request without waiting for a response
|
||||
///
|
||||
/// Any response that would be received gets sent over the broadcast channel instead
|
||||
pub async fn fire(&mut self, req: Request) -> Result<(), TransportError> {
|
||||
self.t_write.send(req).await
|
||||
}
|
||||
|
||||
/// Sends a request without waiting for a response, timing out after duration has passed
|
||||
pub async fn fire_timeout(
|
||||
&mut self,
|
||||
req: Request,
|
||||
duration: Duration,
|
||||
) -> Result<(), TransportError> {
|
||||
utils::timeout(duration, self.fire(req))
|
||||
.await
|
||||
.map_err(TransportError::from)
|
||||
.and_then(convert::identity)
|
||||
}
|
||||
|
||||
/// Clones a new instance of the broadcaster used by the client
|
||||
pub fn to_response_broadcaster(&self) -> broadcast::Sender<Response> {
|
||||
self.broadcast.clone()
|
||||
}
|
||||
|
||||
/// Creates and returns a new stream of responses that are received that do not match the
|
||||
/// response to a `send` request
|
||||
pub fn to_response_broadcast_stream(&mut self) -> BroadcastStream<Response> {
|
||||
BroadcastStream::new(
|
||||
self.init_broadcast_receiver
|
||||
.take()
|
||||
.unwrap_or_else(|| self.broadcast.subscribe()),
|
||||
)
|
||||
}
|
||||
}
|
||||
mod client;
|
||||
pub use client::Client;
|
||||
|
@ -41,6 +41,12 @@ impl Decoder for DistantCodec {
|
||||
|
||||
// Second, retrieve total size of our frame's message
|
||||
let msg_len = u64::from_be_bytes(src[..LEN_SIZE].try_into().unwrap());
|
||||
if msg_len == 0 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"Frame cannot have msg len of 0",
|
||||
));
|
||||
}
|
||||
|
||||
// Third, return our msg if it's available, stripping it of the length data
|
||||
let frame_len = frame_size(msg_len as usize);
|
||||
@ -56,3 +62,138 @@ impl Decoder for DistantCodec {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn encoder_should_encode_byte_slice_with_frame_size() {
|
||||
let mut encoder = DistantCodec;
|
||||
let mut buf = BytesMut::new();
|
||||
|
||||
// Verify that first encoding properly includes size and data
|
||||
// Format is {N as 8 bytes}{data as N bytes}
|
||||
encoder.encode(&[1, 2, 3], &mut buf).unwrap();
|
||||
assert_eq!(
|
||||
buf,
|
||||
vec![/* Size of 3 as u64 */ 0, 0, 0, 0, 0, 0, 0, 3, /* Data */ 1, 2, 3],
|
||||
);
|
||||
|
||||
// Verify that second encoding properly adds to end of buffer and doesn't overwrite
|
||||
encoder.encode(&[4, 5, 6, 7, 8, 9], &mut buf).unwrap();
|
||||
assert_eq!(
|
||||
buf,
|
||||
vec![
|
||||
/* First encoding */ 0, 0, 0, 0, 0, 0, 0, 3, 1, 2, 3,
|
||||
/* Second encoding */ 0, 0, 0, 0, 0, 0, 0, 6, 4, 5, 6, 7, 8, 9,
|
||||
],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decoder_should_return_none_if_received_data_smaller_than_frame_length_field() {
|
||||
let mut decoder = DistantCodec;
|
||||
let mut buf = BytesMut::new();
|
||||
|
||||
// Put 1 less than frame len field size
|
||||
for i in 0..LEN_SIZE {
|
||||
buf.put_u8(i as u8);
|
||||
}
|
||||
|
||||
match decoder.decode(&mut buf) {
|
||||
Ok(None) => {}
|
||||
x => panic!("decoder.decode(...) wanted Ok(None), but got {:?}", x),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decoder_should_return_none_if_received_data_is_not_a_full_frame() {
|
||||
let mut decoder = DistantCodec;
|
||||
let mut buf = BytesMut::new();
|
||||
|
||||
// Put the length of our frame, but no frame at all
|
||||
buf.put_u64(4);
|
||||
|
||||
match decoder.decode(&mut buf) {
|
||||
Ok(None) => {}
|
||||
x => panic!("decoder.decode(...) wanted Ok(None), but got {:?}", x),
|
||||
}
|
||||
|
||||
// Put part of the frame, but not the full frame (3 out of 4 bytes)
|
||||
buf.put_u8(1);
|
||||
buf.put_u8(2);
|
||||
buf.put_u8(3);
|
||||
|
||||
match decoder.decode(&mut buf) {
|
||||
Ok(None) => {}
|
||||
x => panic!("decoder.decode(...) wanted Ok(None), but got {:?}", x),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decoder_should_decode_and_return_next_frame_if_available() {
|
||||
let mut decoder = DistantCodec;
|
||||
let mut buf = BytesMut::new();
|
||||
|
||||
// Put exactly a frame via the length and then the data
|
||||
buf.put_u64(4);
|
||||
buf.put_u8(1);
|
||||
buf.put_u8(2);
|
||||
buf.put_u8(3);
|
||||
buf.put_u8(4);
|
||||
|
||||
match decoder.decode(&mut buf) {
|
||||
Ok(Some(data)) => assert_eq!(data, [1, 2, 3, 4]),
|
||||
x => panic!(
|
||||
"decoder.decode(...) wanted Ok(Vec[1, 2, 3, 4]), but got {:?}",
|
||||
x
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decoder_should_properly_remove_decoded_frame_from_byte_buffer() {
|
||||
let mut decoder = DistantCodec;
|
||||
let mut buf = BytesMut::new();
|
||||
|
||||
// Put exactly a frame via the length and then the data
|
||||
buf.put_u64(4);
|
||||
buf.put_u8(1);
|
||||
buf.put_u8(2);
|
||||
buf.put_u8(3);
|
||||
buf.put_u8(4);
|
||||
|
||||
// Add a little bit more post frame
|
||||
buf.put_u8(123);
|
||||
|
||||
match decoder.decode(&mut buf) {
|
||||
Ok(Some(data)) => {
|
||||
assert_eq!(data, [1, 2, 3, 4]);
|
||||
assert_eq!(buf, vec![123]);
|
||||
}
|
||||
x => panic!(
|
||||
"decoder.decode(...) wanted Ok(Vec[1, 2, 3, 4]), but got {:?}",
|
||||
x
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decoder_should_return_error_if_frame_has_msg_len_of_zero() {
|
||||
let mut decoder = DistantCodec;
|
||||
let mut buf = BytesMut::new();
|
||||
|
||||
// Put a bad frame with a msg len of 0
|
||||
buf.put_u64(0);
|
||||
buf.put_u8(1);
|
||||
|
||||
match decoder.decode(&mut buf) {
|
||||
Err(x) => assert_eq!(x.kind(), io::ErrorKind::InvalidData),
|
||||
x => panic!(
|
||||
"decoder.decode(...) wanted Err(io::ErrorKind::InvalidData), but got {:?}",
|
||||
x
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -79,6 +79,88 @@ impl DataStream for net::UnixStream {
|
||||
}
|
||||
}
|
||||
|
||||
/// Sends some data across the wire, waiting for it to completely send
|
||||
macro_rules! send {
|
||||
($conn:expr, $crypt_key:expr, $auth_key:expr, $data:expr) => {
|
||||
async {
|
||||
// Serialize, encrypt, and then sign
|
||||
// NOTE: Cannot used packed implementation for now due to issues with deserialization
|
||||
let data = serde_cbor::to_vec(&$data)?;
|
||||
|
||||
let data = aead::seal(&$crypt_key, &data).map_err(TransportError::EncryptError)?;
|
||||
let tag = $auth_key
|
||||
.as_ref()
|
||||
.map(|key| auth::authenticate(key, &data))
|
||||
.transpose()
|
||||
.map_err(TransportError::AuthError)?;
|
||||
|
||||
// Send {TAG LEN}{TAG}{ENCRYPTED DATA} if we have an auth key,
|
||||
// otherwise just send the encrypted data on its own
|
||||
let mut out: Vec<u8> = Vec::new();
|
||||
if let Some(tag) = tag {
|
||||
let tag_len = tag.unprotected_as_bytes().len() as u8;
|
||||
|
||||
out.push(tag_len);
|
||||
out.extend_from_slice(tag.unprotected_as_bytes());
|
||||
}
|
||||
out.extend(data);
|
||||
|
||||
$conn.send(&out).await.map_err(TransportError::from)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! recv {
|
||||
($conn:expr, $crypt_key:expr, $auth_key:expr) => {
|
||||
async {
|
||||
// If data is received, we process like usual
|
||||
if let Some(data) = $conn.next().await {
|
||||
let mut data = data?;
|
||||
|
||||
if data.is_empty() {
|
||||
return Err(TransportError::from(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"Received data is empty",
|
||||
)));
|
||||
}
|
||||
|
||||
// Retrieve in form {TAG LEN}{TAG}{ENCRYPTED DATA}
|
||||
// with the tag len and tag being optional
|
||||
if let Some(auth_key) = $auth_key.as_ref() {
|
||||
// Parse the tag from the length, protecting against bad lengths
|
||||
let tag_len = data[0];
|
||||
if data.len() <= tag_len as usize {
|
||||
return Err(TransportError::from(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("Tag len {} > Data len {}", tag_len, data.len()),
|
||||
)));
|
||||
}
|
||||
|
||||
let tag = Tag::from_slice(&data[1..=tag_len as usize])
|
||||
.map_err(TransportError::AuthError)?;
|
||||
|
||||
// Update data with the content after the tag by mutating
|
||||
// the current data to point to the return from split_off
|
||||
data = data.split_off(tag_len as usize + 1);
|
||||
|
||||
// Validate signature, decrypt, and then deserialize
|
||||
auth::authenticate_verify(&tag, auth_key, &data)
|
||||
.map_err(TransportError::AuthError)?;
|
||||
}
|
||||
|
||||
let data = aead::open(&$crypt_key, &data).map_err(TransportError::EncryptError)?;
|
||||
|
||||
let data = serde_cbor::from_slice(&data)?;
|
||||
Ok(Some(data))
|
||||
|
||||
// Otherwise, if no data is received, this means that our socket has closed
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Represents a transport of data across the network
|
||||
pub struct Transport<T>
|
||||
where
|
||||
@ -132,6 +214,15 @@ where
|
||||
"Stream ended before handshake completed",
|
||||
)
|
||||
})??;
|
||||
|
||||
// If the data we received is too small, return an error
|
||||
if data.len() <= SALT_LEN {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"Response had size smaller than expected",
|
||||
));
|
||||
}
|
||||
|
||||
let (salt_bytes, other_public_key_bytes) = data.split_at(SALT_LEN);
|
||||
let other_salt = Salt::from_slice(salt_bytes)
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))?;
|
||||
@ -174,12 +265,20 @@ where
|
||||
crypt_key,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Transport<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + DataStream + Unpin,
|
||||
{
|
||||
/// Sends some data across the wire, waiting for it to completely send
|
||||
#[allow(dead_code)]
|
||||
pub async fn send<D: Serialize>(&mut self, data: D) -> Result<(), TransportError> {
|
||||
send!(self.conn, self.crypt_key, self.auth_key.as_ref(), data).await
|
||||
}
|
||||
|
||||
/// Receives some data from out on the wire, waiting until it's available,
|
||||
/// returning none if the transport is now closed
|
||||
#[allow(dead_code)]
|
||||
pub async fn receive<R: DeserializeOwned>(&mut self) -> Result<Option<R>, TransportError> {
|
||||
recv!(self.conn, self.crypt_key, self.auth_key).await
|
||||
}
|
||||
|
||||
/// Splits transport into read and write halves
|
||||
pub fn into_split(self) -> (TransportReadHalf<T::Read>, TransportWriteHalf<T::Write>) {
|
||||
let crypt_key = self.crypt_key;
|
||||
@ -266,30 +365,7 @@ where
|
||||
{
|
||||
/// Sends some data across the wire, waiting for it to completely send
|
||||
pub async fn send<D: Serialize>(&mut self, data: D) -> Result<(), TransportError> {
|
||||
// Serialize, encrypt, and then sign
|
||||
// NOTE: Cannot used packed implementation for now due to issues with deserialization
|
||||
let data = serde_cbor::to_vec(&data)?;
|
||||
|
||||
let data = aead::seal(&self.crypt_key, &data).map_err(TransportError::EncryptError)?;
|
||||
let tag = self
|
||||
.auth_key
|
||||
.as_ref()
|
||||
.map(|key| auth::authenticate(key, &data))
|
||||
.transpose()
|
||||
.map_err(TransportError::AuthError)?;
|
||||
|
||||
// Send {TAG LEN}{TAG}{ENCRYPTED DATA} if we have an auth key,
|
||||
// otherwise just send the encrypted data on its own
|
||||
let mut out: Vec<u8> = Vec::new();
|
||||
if let Some(tag) = tag {
|
||||
let tag_len = tag.unprotected_as_bytes().len() as u8;
|
||||
|
||||
out.push(tag_len);
|
||||
out.extend_from_slice(tag.unprotected_as_bytes());
|
||||
}
|
||||
out.extend(data);
|
||||
|
||||
self.conn.send(&out).await.map_err(TransportError::from)
|
||||
send!(self.conn, self.crypt_key, self.auth_key.as_ref(), data).await
|
||||
}
|
||||
}
|
||||
|
||||
@ -315,49 +391,321 @@ where
|
||||
/// Receives some data from out on the wire, waiting until it's available,
|
||||
/// returning none if the transport is now closed
|
||||
pub async fn receive<R: DeserializeOwned>(&mut self) -> Result<Option<R>, TransportError> {
|
||||
// If data is received, we process like usual
|
||||
if let Some(data) = self.conn.next().await {
|
||||
let mut data = data?;
|
||||
recv!(self.conn, self.crypt_key, self.auth_key).await
|
||||
}
|
||||
}
|
||||
|
||||
if data.is_empty() {
|
||||
return Err(TransportError::from(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"Received data is empty",
|
||||
)));
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::{
|
||||
io,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tokio::{io::ReadBuf, sync::mpsc};
|
||||
|
||||
pub const TEST_DATA_STREAM_CHANNEL_BUFFER_SIZE: usize = 100;
|
||||
|
||||
/// Represents a data stream comprised of two inmemory buffers of data
|
||||
pub struct TestDataStream {
|
||||
incoming: TestDataStreamReadHalf,
|
||||
outgoing: TestDataStreamWriteHalf,
|
||||
}
|
||||
|
||||
impl TestDataStream {
|
||||
pub fn new(incoming: mpsc::Receiver<Vec<u8>>, outgoing: mpsc::Sender<Vec<u8>>) -> Self {
|
||||
Self {
|
||||
incoming: TestDataStreamReadHalf(incoming),
|
||||
outgoing: TestDataStreamWriteHalf(outgoing),
|
||||
}
|
||||
}
|
||||
|
||||
// Retrieve in form {TAG LEN}{TAG}{ENCRYPTED DATA}
|
||||
// with the tag len and tag being optional
|
||||
if let Some(auth_key) = self.auth_key.as_ref() {
|
||||
// Parse the tag from the length, protecting against bad lengths
|
||||
let tag_len = data[0];
|
||||
if data.len() <= tag_len as usize {
|
||||
return Err(TransportError::from(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("Tag len {} > Data len {}", tag_len, data.len()),
|
||||
)));
|
||||
/// Returns (incoming_tx, outgoing_rx, stream)
|
||||
pub fn make() -> (mpsc::Sender<Vec<u8>>, mpsc::Receiver<Vec<u8>>, Self) {
|
||||
let (incoming_tx, incoming_rx) = mpsc::channel(TEST_DATA_STREAM_CHANNEL_BUFFER_SIZE);
|
||||
let (outgoing_tx, outgoing_rx) = mpsc::channel(TEST_DATA_STREAM_CHANNEL_BUFFER_SIZE);
|
||||
|
||||
(
|
||||
incoming_tx,
|
||||
outgoing_rx,
|
||||
Self::new(incoming_rx, outgoing_tx),
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns pair of streams that are connected such that one sends to the other and
|
||||
/// vice versa
|
||||
pub fn pair() -> (Self, Self) {
|
||||
let (tx, rx, stream) = Self::make();
|
||||
(stream, Self::new(rx, tx))
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for TestDataStream {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.incoming).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for TestDataStream {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
Pin::new(&mut self.outgoing).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.outgoing).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.outgoing).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TestDataStreamReadHalf(mpsc::Receiver<Vec<u8>>);
|
||||
impl AsyncRead for TestDataStreamReadHalf {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
self.0.poll_recv(cx).map(|x| match x {
|
||||
Some(x) => {
|
||||
buf.put_slice(&x);
|
||||
Ok(())
|
||||
}
|
||||
None => Ok(()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
let tag = Tag::from_slice(&data[1..=tag_len as usize])
|
||||
.map_err(TransportError::AuthError)?;
|
||||
|
||||
// Update data with the content after the tag by mutating
|
||||
// the current data to point to the return from split_off
|
||||
data = data.split_off(tag_len as usize + 1);
|
||||
|
||||
// Validate signature, decrypt, and then deserialize
|
||||
auth::authenticate_verify(&tag, auth_key, &data)
|
||||
.map_err(TransportError::AuthError)?;
|
||||
pub struct TestDataStreamWriteHalf(mpsc::Sender<Vec<u8>>);
|
||||
impl AsyncWrite for TestDataStreamWriteHalf {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
_: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
match self.0.try_send(buf.to_vec()) {
|
||||
Ok(_) => Poll::Ready(Ok(buf.len())),
|
||||
Err(_) => Poll::Ready(Ok(0)),
|
||||
}
|
||||
}
|
||||
|
||||
let data = aead::open(&self.crypt_key, &data).map_err(TransportError::EncryptError)?;
|
||||
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
let data = serde_cbor::from_slice(&data)?;
|
||||
Ok(Some(data))
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
self.poll_flush(cx)
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise, if no data is received, this means that our socket has closed
|
||||
} else {
|
||||
Ok(None)
|
||||
impl DataStream for TestDataStream {
|
||||
type Read = TestDataStreamReadHalf;
|
||||
type Write = TestDataStreamWriteHalf;
|
||||
|
||||
fn to_connection_tag(&self) -> String {
|
||||
String::from("test-stream")
|
||||
}
|
||||
|
||||
fn into_split(self) -> (Self::Read, Self::Write) {
|
||||
(self.incoming, self.outgoing)
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_from_handshake_should_fail_if_connection_reached_eof() {
|
||||
// Cause nothing left incoming to stream by _
|
||||
let (_, mut rx, stream) = TestDataStream::make();
|
||||
let result = Transport::from_handshake(stream, None).await;
|
||||
|
||||
// Verify that a salt and public key were sent out first
|
||||
// 1. Frame includes an 8 byte size at beginning
|
||||
// 2. Salt len + 256-bit (32 byte) public key + 1 byte tag (len) for pub key
|
||||
let outgoing = rx.recv().await.unwrap();
|
||||
assert_eq!(
|
||||
outgoing.len(),
|
||||
8 + SALT_LEN + 33,
|
||||
"Unexpected outgoing data: {:?}",
|
||||
outgoing
|
||||
);
|
||||
|
||||
// Then confirm that failed because didn't receive anything back
|
||||
match result {
|
||||
Err(x) if x.kind() == io::ErrorKind::UnexpectedEof => {}
|
||||
Err(x) => panic!("Unexpected error: {:?}", x),
|
||||
Ok(_) => panic!("Unexpectedly succeeded!"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_from_handshake_should_fail_if_response_data_is_too_small() {
|
||||
let (tx, _rx, stream) = TestDataStream::make();
|
||||
|
||||
// Need SALT + PUB KEY where salt has a defined size; so, at least 1 larger than salt
|
||||
// would succeed, whereas we are providing exactly salt, which will fail
|
||||
{
|
||||
let mut frame = Vec::new();
|
||||
frame.extend_from_slice(&(SALT_LEN as u64).to_be_bytes());
|
||||
frame.extend_from_slice(Salt::generate(SALT_LEN).unwrap().as_ref());
|
||||
tx.send(frame).await.unwrap();
|
||||
drop(tx);
|
||||
}
|
||||
|
||||
match Transport::from_handshake(stream, None).await {
|
||||
Err(x) if x.kind() == io::ErrorKind::InvalidData => {}
|
||||
Err(x) => panic!("Unexpected error: {:?}", x),
|
||||
Ok(_) => panic!("Unexpectedly succeeded!"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_from_handshake_should_fail_if_bad_foreign_public_key_received() {
|
||||
let (tx, _rx, stream) = TestDataStream::make();
|
||||
|
||||
// Send {SALT LEN}{SALT}{PUB KEY} where public key is bad;
|
||||
// normally public key bytes would be {LEN}{KEY} where len is first byte;
|
||||
// if the len does not match the rest of the message len, an error will be returned
|
||||
{
|
||||
let mut frame = Vec::new();
|
||||
frame.extend_from_slice(&((SALT_LEN + 3) as u64).to_be_bytes());
|
||||
frame.extend_from_slice(Salt::generate(SALT_LEN).unwrap().as_ref());
|
||||
frame.extend_from_slice(&[1, 1, 2]);
|
||||
tx.send(frame).await.unwrap();
|
||||
drop(tx);
|
||||
}
|
||||
|
||||
match Transport::from_handshake(stream, None).await {
|
||||
Err(x) if x.kind() == io::ErrorKind::InvalidData => {
|
||||
let source = x.into_inner().expect("Inner source missing");
|
||||
assert_eq!(
|
||||
source.to_string(),
|
||||
"crypto error",
|
||||
"Unexpected source: {}",
|
||||
source
|
||||
);
|
||||
}
|
||||
Err(x) => panic!("Unexpected error: {:?}", x),
|
||||
Ok(_) => panic!("Unexpectedly succeeded!"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_should_be_able_to_send_encrypted_data_to_other_side_to_decrypt() {
|
||||
let (src, dst) = TestDataStream::pair();
|
||||
|
||||
// NOTE: This is slow during tests as it is an expensive process and we're doing it twice!
|
||||
let (src, dst) = tokio::join!(
|
||||
Transport::from_handshake(src, None),
|
||||
Transport::from_handshake(dst, None)
|
||||
);
|
||||
|
||||
let mut src = src.expect("src stream failed handshake");
|
||||
let mut dst = dst.expect("dst stream failed handshake");
|
||||
|
||||
src.send("some data").await.expect("Failed to send data");
|
||||
let data = dst
|
||||
.receive::<String>()
|
||||
.await
|
||||
.expect("Failed to receive data")
|
||||
.expect("Data missing");
|
||||
|
||||
assert_eq!(data, "some data");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_should_be_able_to_sign_and_validate_signature_if_auth_key_included() {
|
||||
let (src, dst) = TestDataStream::pair();
|
||||
|
||||
let auth_key = Arc::new(SecretKey::default());
|
||||
|
||||
// NOTE: This is slow during tests as it is an expensive process and we're doing it twice!
|
||||
let (src, dst) = tokio::join!(
|
||||
Transport::from_handshake(src, Some(Arc::clone(&auth_key))),
|
||||
Transport::from_handshake(dst, Some(auth_key))
|
||||
);
|
||||
|
||||
let mut src = src.expect("src stream failed handshake");
|
||||
let mut dst = dst.expect("dst stream failed handshake");
|
||||
|
||||
src.send("some data").await.expect("Failed to send data");
|
||||
let data = dst
|
||||
.receive::<String>()
|
||||
.await
|
||||
.expect("Failed to receive data")
|
||||
.expect("Data missing");
|
||||
|
||||
assert_eq!(data, "some data");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_receive_should_fail_if_auth_key_differs_from_other_end() {
|
||||
let (src, dst) = TestDataStream::pair();
|
||||
|
||||
// Make two transports with different auth keys
|
||||
// NOTE: This is slow during tests as it is an expensive process and we're doing it twice!
|
||||
let (src, dst) = tokio::join!(
|
||||
Transport::from_handshake(src, Some(Arc::new(SecretKey::default()))),
|
||||
Transport::from_handshake(dst, Some(Arc::new(SecretKey::default())))
|
||||
);
|
||||
|
||||
let mut src = src.expect("src stream failed handshake");
|
||||
let mut dst = dst.expect("dst stream failed handshake");
|
||||
|
||||
src.send("some data").await.expect("Failed to send data");
|
||||
match dst.receive::<String>().await {
|
||||
Err(TransportError::AuthError(_)) => {}
|
||||
x => panic!("Unexpected result: {:?}", x),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_receive_should_fail_if_has_auth_key_while_sender_did_not_use_one() {
|
||||
let (src, dst) = TestDataStream::pair();
|
||||
|
||||
// Make two transports with different auth keys
|
||||
// NOTE: This is slow during tests as it is an expensive process and we're doing it twice!
|
||||
let (src, dst) = tokio::join!(
|
||||
Transport::from_handshake(dst, None),
|
||||
Transport::from_handshake(src, Some(Arc::new(SecretKey::default())))
|
||||
);
|
||||
|
||||
let mut src = src.expect("src stream failed handshake");
|
||||
let mut dst = dst.expect("dst stream failed handshake");
|
||||
|
||||
src.send("some data").await.expect("Failed to send data");
|
||||
match dst.receive::<String>().await {
|
||||
Err(TransportError::AuthError(_)) => {}
|
||||
x => panic!("Unexpected result: {:?}", x),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_receive_should_fail_if_has_no_auth_key_while_sender_used_one() {
|
||||
let (src, dst) = TestDataStream::pair();
|
||||
|
||||
// Make two transports with different auth keys
|
||||
// NOTE: This is slow during tests as it is an expensive process and we're doing it twice!
|
||||
let (src, dst) = tokio::join!(
|
||||
Transport::from_handshake(src, Some(Arc::new(SecretKey::default()))),
|
||||
Transport::from_handshake(dst, None)
|
||||
);
|
||||
|
||||
let mut src = src.expect("src stream failed handshake");
|
||||
let mut dst = dst.expect("dst stream failed handshake");
|
||||
|
||||
src.send("some data").await.expect("Failed to send data");
|
||||
match dst.receive::<String>().await {
|
||||
Err(TransportError::EncryptError(_)) => {}
|
||||
x => panic!("Unexpected result: {:?}", x),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user