Add tests for codec and transport; move net::client to dedicated file

This commit is contained in:
Chip Senkbeil 2021-08-18 02:34:04 -05:00
parent f6e9195503
commit e857dabe43
No known key found for this signature in database
GPG Key ID: 35EF1F8EC72A4131
4 changed files with 781 additions and 273 deletions

227
src/core/net/client.rs Normal file
View File

@ -0,0 +1,227 @@
use crate::core::{
constants::CLIENT_BROADCAST_CHANNEL_CAPACITY,
data::{Request, Response},
net::{DataStream, Transport, TransportError, TransportWriteHalf},
session::Session,
utils,
};
use log::*;
use std::{
collections::HashMap,
convert,
sync::{Arc, Mutex},
};
use tokio::{
io,
net::TcpStream,
sync::{broadcast, oneshot},
task::{JoinError, JoinHandle},
time::Duration,
};
use tokio_stream::wrappers::BroadcastStream;
type Callbacks = Arc<Mutex<HashMap<usize, oneshot::Sender<Response>>>>;
/// Represents a client that can make requests against a server
pub struct Client<T>
where
T: DataStream,
{
/// Underlying transport used by client
t_write: TransportWriteHalf<T::Write>,
/// Collection of callbacks to be invoked upon receiving a response to a request
callbacks: Callbacks,
/// Callback to trigger when a response is received without an origin or with an origin
/// not found in the list of callbacks
broadcast: broadcast::Sender<Response>,
/// Represents an initial receiver for broadcasted responses that can capture responses
/// prior to a stream being established and consumed
init_broadcast_receiver: Option<broadcast::Receiver<Response>>,
/// Contains the task that is running to receive responses from a server
response_task: JoinHandle<()>,
}
impl Client<TcpStream> {
/// Connect to a remote TCP session
pub async fn tcp_connect(session: Session) -> io::Result<Self> {
let transport = Transport::<TcpStream>::connect(session).await?;
debug!(
"Client has connected to {}",
transport
.peer_addr()
.map(|x| x.to_string())
.unwrap_or_else(|_| String::from("???"))
);
Self::inner_connect(transport).await
}
/// Connect to a remote TCP session, timing out after duration has passed
pub async fn tcp_connect_timeout(session: Session, duration: Duration) -> io::Result<Self> {
utils::timeout(duration, Self::tcp_connect(session))
.await
.and_then(convert::identity)
}
}
#[cfg(unix)]
impl Client<tokio::net::UnixStream> {
/// Connect to a proxy unix socket
pub async fn unix_connect(
path: impl AsRef<std::path::Path>,
auth_key: Option<Arc<orion::aead::SecretKey>>,
) -> io::Result<Self> {
let transport = Transport::<tokio::net::UnixStream>::connect(path, auth_key).await?;
debug!(
"Client has connected to {}",
transport
.peer_addr()
.map(|x| format!("{:?}", x))
.unwrap_or_else(|_| String::from("???"))
);
Self::inner_connect(transport).await
}
/// Connect to a proxy unix socket, timing out after duration has passed
pub async fn unix_connect_timeout(
path: impl AsRef<std::path::Path>,
auth_key: Option<Arc<orion::aead::SecretKey>>,
duration: Duration,
) -> io::Result<Self> {
utils::timeout(duration, Self::unix_connect(path, auth_key))
.await
.and_then(convert::identity)
}
}
impl<T> Client<T>
where
T: DataStream,
{
/// Establishes a connection using the provided session
async fn inner_connect(transport: Transport<T>) -> io::Result<Self> {
let (mut t_read, t_write) = transport.into_split();
let callbacks: Callbacks = Arc::new(Mutex::new(HashMap::new()));
let (broadcast, init_broadcast_receiver) =
broadcast::channel(CLIENT_BROADCAST_CHANNEL_CAPACITY);
// Start a task that continually checks for responses and triggers callbacks
let callbacks_2 = Arc::clone(&callbacks);
let broadcast_2 = broadcast.clone();
let response_task = tokio::spawn(async move {
loop {
match t_read.receive::<Response>().await {
Ok(Some(res)) => {
trace!("Client got response: {:?}", res);
let maybe_callback = res
.origin_id
.as_ref()
.and_then(|id| callbacks_2.lock().unwrap().remove(id));
// If there is an origin to this response, trigger the callback
if let Some(tx) = maybe_callback {
trace!("Client has callback! Triggering!");
if let Err(res) = tx.send(res) {
error!("Failed to trigger callback for response {}", res.id);
}
// Otherwise, this goes into the junk draw of response handlers
} else {
trace!("Client does not have callback! Broadcasting!");
if let Err(x) = broadcast_2.send(res) {
error!("Failed to trigger broadcast: {}", x);
}
}
}
Ok(None) => break,
Err(x) => {
error!("{}", x);
break;
}
}
}
});
Ok(Self {
t_write,
callbacks,
broadcast,
init_broadcast_receiver: Some(init_broadcast_receiver),
response_task,
})
}
/// Waits for the client to terminate, which results when the receiving end of the network
/// connection is closed (or the client is shutdown)
pub async fn wait(self) -> Result<(), JoinError> {
self.response_task.await
}
/// Abort the client's current connection by forcing its response task to shutdown
pub fn abort(&self) {
self.response_task.abort()
}
/// Sends a request and waits for a response
pub async fn send(&mut self, req: Request) -> Result<Response, TransportError> {
// First, add a callback that will trigger when we get the response for this request
let (tx, rx) = oneshot::channel();
self.callbacks.lock().unwrap().insert(req.id, tx);
// Second, send the request
self.t_write.send(req).await?;
// Third, wait for the response
rx.await
.map_err(|x| TransportError::from(io::Error::new(io::ErrorKind::ConnectionAborted, x)))
}
/// Sends a request and waits for a response, timing out after duration has passed
pub async fn send_timeout(
&mut self,
req: Request,
duration: Duration,
) -> Result<Response, TransportError> {
utils::timeout(duration, self.send(req))
.await
.map_err(TransportError::from)
.and_then(convert::identity)
}
/// Sends a request without waiting for a response
///
/// Any response that would be received gets sent over the broadcast channel instead
pub async fn fire(&mut self, req: Request) -> Result<(), TransportError> {
self.t_write.send(req).await
}
/// Sends a request without waiting for a response, timing out after duration has passed
pub async fn fire_timeout(
&mut self,
req: Request,
duration: Duration,
) -> Result<(), TransportError> {
utils::timeout(duration, self.fire(req))
.await
.map_err(TransportError::from)
.and_then(convert::identity)
}
/// Clones a new instance of the broadcaster used by the client
pub fn to_response_broadcaster(&self) -> broadcast::Sender<Response> {
self.broadcast.clone()
}
/// Creates and returns a new stream of responses that are received that do not match the
/// response to a `send` request
pub fn to_response_broadcast_stream(&mut self) -> BroadcastStream<Response> {
BroadcastStream::new(
self.init_broadcast_receiver
.take()
.unwrap_or_else(|| self.broadcast.subscribe()),
)
}
}

View File

@ -1,213 +1,5 @@
mod transport;
pub use transport::{DataStream, Transport, TransportError, TransportReadHalf, TransportWriteHalf};
use crate::core::{
constants::CLIENT_BROADCAST_CHANNEL_CAPACITY,
data::{Request, Response},
session::Session,
utils,
};
use log::*;
use std::{
collections::HashMap,
convert,
sync::{Arc, Mutex},
};
use tokio::{
io,
net::TcpStream,
sync::{broadcast, oneshot},
time::Duration,
};
use tokio_stream::wrappers::BroadcastStream;
type Callbacks = Arc<Mutex<HashMap<usize, oneshot::Sender<Response>>>>;
/// Represents a client that can make requests against a server
pub struct Client<T>
where
T: DataStream,
{
/// Underlying transport used by client
t_write: TransportWriteHalf<T::Write>,
/// Collection of callbacks to be invoked upon receiving a response to a request
callbacks: Callbacks,
/// Callback to trigger when a response is received without an origin or with an origin
/// not found in the list of callbacks
broadcast: broadcast::Sender<Response>,
/// Represents an initial receiver for broadcasted responses that can capture responses
/// prior to a stream being established and consumed
init_broadcast_receiver: Option<broadcast::Receiver<Response>>,
}
impl Client<TcpStream> {
/// Connect to a remote TCP session
pub async fn tcp_connect(session: Session) -> io::Result<Self> {
let transport = Transport::<TcpStream>::connect(session).await?;
debug!(
"Client has connected to {}",
transport
.peer_addr()
.map(|x| x.to_string())
.unwrap_or_else(|_| String::from("???"))
);
Self::inner_connect(transport).await
}
/// Connect to a remote TCP session, timing out after duration has passed
pub async fn tcp_connect_timeout(session: Session, duration: Duration) -> io::Result<Self> {
utils::timeout(duration, Self::tcp_connect(session))
.await
.and_then(convert::identity)
}
}
#[cfg(unix)]
impl Client<tokio::net::UnixStream> {
/// Connect to a proxy unix socket
pub async fn unix_connect(
path: impl AsRef<std::path::Path>,
auth_key: Option<Arc<orion::aead::SecretKey>>,
) -> io::Result<Self> {
let transport = Transport::<tokio::net::UnixStream>::connect(path, auth_key).await?;
debug!(
"Client has connected to {}",
transport
.peer_addr()
.map(|x| format!("{:?}", x))
.unwrap_or_else(|_| String::from("???"))
);
Self::inner_connect(transport).await
}
/// Connect to a proxy unix socket, timing out after duration has passed
pub async fn unix_connect_timeout(
path: impl AsRef<std::path::Path>,
auth_key: Option<Arc<orion::aead::SecretKey>>,
duration: Duration,
) -> io::Result<Self> {
utils::timeout(duration, Self::unix_connect(path, auth_key))
.await
.and_then(convert::identity)
}
}
impl<T> Client<T>
where
T: DataStream,
{
/// Establishes a connection using the provided session
async fn inner_connect(transport: Transport<T>) -> io::Result<Self> {
let (mut t_read, t_write) = transport.into_split();
let callbacks: Callbacks = Arc::new(Mutex::new(HashMap::new()));
let (broadcast, init_broadcast_receiver) =
broadcast::channel(CLIENT_BROADCAST_CHANNEL_CAPACITY);
// Start a task that continually checks for responses and triggers callbacks
let callbacks_2 = Arc::clone(&callbacks);
let broadcast_2 = broadcast.clone();
tokio::spawn(async move {
loop {
match t_read.receive::<Response>().await {
Ok(Some(res)) => {
trace!("Client got response: {:?}", res);
let maybe_callback = res
.origin_id
.as_ref()
.and_then(|id| callbacks_2.lock().unwrap().remove(id));
// If there is an origin to this response, trigger the callback
if let Some(tx) = maybe_callback {
trace!("Client has callback! Triggering!");
if let Err(res) = tx.send(res) {
error!("Failed to trigger callback for response {}", res.id);
}
// Otherwise, this goes into the junk draw of response handlers
} else {
trace!("Client does not have callback! Broadcasting!");
if let Err(x) = broadcast_2.send(res) {
error!("Failed to trigger broadcast: {}", x);
}
}
}
Ok(None) => break,
Err(x) => {
error!("{}", x);
break;
}
}
}
});
Ok(Self {
t_write,
callbacks,
broadcast,
init_broadcast_receiver: Some(init_broadcast_receiver),
})
}
/// Sends a request and waits for a response
pub async fn send(&mut self, req: Request) -> Result<Response, TransportError> {
// First, add a callback that will trigger when we get the response for this request
let (tx, rx) = oneshot::channel();
self.callbacks.lock().unwrap().insert(req.id, tx);
// Second, send the request
self.t_write.send(req).await?;
// Third, wait for the response
rx.await
.map_err(|x| TransportError::from(io::Error::new(io::ErrorKind::ConnectionAborted, x)))
}
/// Sends a request and waits for a response, timing out after duration has passed
pub async fn send_timeout(
&mut self,
req: Request,
duration: Duration,
) -> Result<Response, TransportError> {
utils::timeout(duration, self.send(req))
.await
.map_err(TransportError::from)
.and_then(convert::identity)
}
/// Sends a request without waiting for a response
///
/// Any response that would be received gets sent over the broadcast channel instead
pub async fn fire(&mut self, req: Request) -> Result<(), TransportError> {
self.t_write.send(req).await
}
/// Sends a request without waiting for a response, timing out after duration has passed
pub async fn fire_timeout(
&mut self,
req: Request,
duration: Duration,
) -> Result<(), TransportError> {
utils::timeout(duration, self.fire(req))
.await
.map_err(TransportError::from)
.and_then(convert::identity)
}
/// Clones a new instance of the broadcaster used by the client
pub fn to_response_broadcaster(&self) -> broadcast::Sender<Response> {
self.broadcast.clone()
}
/// Creates and returns a new stream of responses that are received that do not match the
/// response to a `send` request
pub fn to_response_broadcast_stream(&mut self) -> BroadcastStream<Response> {
BroadcastStream::new(
self.init_broadcast_receiver
.take()
.unwrap_or_else(|| self.broadcast.subscribe()),
)
}
}
mod client;
pub use client::Client;

View File

@ -41,6 +41,12 @@ impl Decoder for DistantCodec {
// Second, retrieve total size of our frame's message
let msg_len = u64::from_be_bytes(src[..LEN_SIZE].try_into().unwrap());
if msg_len == 0 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Frame cannot have msg len of 0",
));
}
// Third, return our msg if it's available, stripping it of the length data
let frame_len = frame_size(msg_len as usize);
@ -56,3 +62,138 @@ impl Decoder for DistantCodec {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn encoder_should_encode_byte_slice_with_frame_size() {
let mut encoder = DistantCodec;
let mut buf = BytesMut::new();
// Verify that first encoding properly includes size and data
// Format is {N as 8 bytes}{data as N bytes}
encoder.encode(&[1, 2, 3], &mut buf).unwrap();
assert_eq!(
buf,
vec![/* Size of 3 as u64 */ 0, 0, 0, 0, 0, 0, 0, 3, /* Data */ 1, 2, 3],
);
// Verify that second encoding properly adds to end of buffer and doesn't overwrite
encoder.encode(&[4, 5, 6, 7, 8, 9], &mut buf).unwrap();
assert_eq!(
buf,
vec![
/* First encoding */ 0, 0, 0, 0, 0, 0, 0, 3, 1, 2, 3,
/* Second encoding */ 0, 0, 0, 0, 0, 0, 0, 6, 4, 5, 6, 7, 8, 9,
],
);
}
#[test]
fn decoder_should_return_none_if_received_data_smaller_than_frame_length_field() {
let mut decoder = DistantCodec;
let mut buf = BytesMut::new();
// Put 1 less than frame len field size
for i in 0..LEN_SIZE {
buf.put_u8(i as u8);
}
match decoder.decode(&mut buf) {
Ok(None) => {}
x => panic!("decoder.decode(...) wanted Ok(None), but got {:?}", x),
}
}
#[test]
fn decoder_should_return_none_if_received_data_is_not_a_full_frame() {
let mut decoder = DistantCodec;
let mut buf = BytesMut::new();
// Put the length of our frame, but no frame at all
buf.put_u64(4);
match decoder.decode(&mut buf) {
Ok(None) => {}
x => panic!("decoder.decode(...) wanted Ok(None), but got {:?}", x),
}
// Put part of the frame, but not the full frame (3 out of 4 bytes)
buf.put_u8(1);
buf.put_u8(2);
buf.put_u8(3);
match decoder.decode(&mut buf) {
Ok(None) => {}
x => panic!("decoder.decode(...) wanted Ok(None), but got {:?}", x),
}
}
#[test]
fn decoder_should_decode_and_return_next_frame_if_available() {
let mut decoder = DistantCodec;
let mut buf = BytesMut::new();
// Put exactly a frame via the length and then the data
buf.put_u64(4);
buf.put_u8(1);
buf.put_u8(2);
buf.put_u8(3);
buf.put_u8(4);
match decoder.decode(&mut buf) {
Ok(Some(data)) => assert_eq!(data, [1, 2, 3, 4]),
x => panic!(
"decoder.decode(...) wanted Ok(Vec[1, 2, 3, 4]), but got {:?}",
x
),
}
}
#[test]
fn decoder_should_properly_remove_decoded_frame_from_byte_buffer() {
let mut decoder = DistantCodec;
let mut buf = BytesMut::new();
// Put exactly a frame via the length and then the data
buf.put_u64(4);
buf.put_u8(1);
buf.put_u8(2);
buf.put_u8(3);
buf.put_u8(4);
// Add a little bit more post frame
buf.put_u8(123);
match decoder.decode(&mut buf) {
Ok(Some(data)) => {
assert_eq!(data, [1, 2, 3, 4]);
assert_eq!(buf, vec![123]);
}
x => panic!(
"decoder.decode(...) wanted Ok(Vec[1, 2, 3, 4]), but got {:?}",
x
),
}
}
#[test]
fn decoder_should_return_error_if_frame_has_msg_len_of_zero() {
let mut decoder = DistantCodec;
let mut buf = BytesMut::new();
// Put a bad frame with a msg len of 0
buf.put_u64(0);
buf.put_u8(1);
match decoder.decode(&mut buf) {
Err(x) => assert_eq!(x.kind(), io::ErrorKind::InvalidData),
x => panic!(
"decoder.decode(...) wanted Err(io::ErrorKind::InvalidData), but got {:?}",
x
),
}
}
}

View File

@ -79,6 +79,88 @@ impl DataStream for net::UnixStream {
}
}
/// Sends some data across the wire, waiting for it to completely send
macro_rules! send {
($conn:expr, $crypt_key:expr, $auth_key:expr, $data:expr) => {
async {
// Serialize, encrypt, and then sign
// NOTE: Cannot used packed implementation for now due to issues with deserialization
let data = serde_cbor::to_vec(&$data)?;
let data = aead::seal(&$crypt_key, &data).map_err(TransportError::EncryptError)?;
let tag = $auth_key
.as_ref()
.map(|key| auth::authenticate(key, &data))
.transpose()
.map_err(TransportError::AuthError)?;
// Send {TAG LEN}{TAG}{ENCRYPTED DATA} if we have an auth key,
// otherwise just send the encrypted data on its own
let mut out: Vec<u8> = Vec::new();
if let Some(tag) = tag {
let tag_len = tag.unprotected_as_bytes().len() as u8;
out.push(tag_len);
out.extend_from_slice(tag.unprotected_as_bytes());
}
out.extend(data);
$conn.send(&out).await.map_err(TransportError::from)
}
};
}
macro_rules! recv {
($conn:expr, $crypt_key:expr, $auth_key:expr) => {
async {
// If data is received, we process like usual
if let Some(data) = $conn.next().await {
let mut data = data?;
if data.is_empty() {
return Err(TransportError::from(io::Error::new(
io::ErrorKind::InvalidData,
"Received data is empty",
)));
}
// Retrieve in form {TAG LEN}{TAG}{ENCRYPTED DATA}
// with the tag len and tag being optional
if let Some(auth_key) = $auth_key.as_ref() {
// Parse the tag from the length, protecting against bad lengths
let tag_len = data[0];
if data.len() <= tag_len as usize {
return Err(TransportError::from(io::Error::new(
io::ErrorKind::InvalidData,
format!("Tag len {} > Data len {}", tag_len, data.len()),
)));
}
let tag = Tag::from_slice(&data[1..=tag_len as usize])
.map_err(TransportError::AuthError)?;
// Update data with the content after the tag by mutating
// the current data to point to the return from split_off
data = data.split_off(tag_len as usize + 1);
// Validate signature, decrypt, and then deserialize
auth::authenticate_verify(&tag, auth_key, &data)
.map_err(TransportError::AuthError)?;
}
let data = aead::open(&$crypt_key, &data).map_err(TransportError::EncryptError)?;
let data = serde_cbor::from_slice(&data)?;
Ok(Some(data))
// Otherwise, if no data is received, this means that our socket has closed
} else {
Ok(None)
}
}
};
}
/// Represents a transport of data across the network
pub struct Transport<T>
where
@ -132,6 +214,15 @@ where
"Stream ended before handshake completed",
)
})??;
// If the data we received is too small, return an error
if data.len() <= SALT_LEN {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Response had size smaller than expected",
));
}
let (salt_bytes, other_public_key_bytes) = data.split_at(SALT_LEN);
let other_salt = Salt::from_slice(salt_bytes)
.map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))?;
@ -174,12 +265,20 @@ where
crypt_key,
})
}
}
impl<T> Transport<T>
where
T: AsyncRead + AsyncWrite + DataStream + Unpin,
{
/// Sends some data across the wire, waiting for it to completely send
#[allow(dead_code)]
pub async fn send<D: Serialize>(&mut self, data: D) -> Result<(), TransportError> {
send!(self.conn, self.crypt_key, self.auth_key.as_ref(), data).await
}
/// Receives some data from out on the wire, waiting until it's available,
/// returning none if the transport is now closed
#[allow(dead_code)]
pub async fn receive<R: DeserializeOwned>(&mut self) -> Result<Option<R>, TransportError> {
recv!(self.conn, self.crypt_key, self.auth_key).await
}
/// Splits transport into read and write halves
pub fn into_split(self) -> (TransportReadHalf<T::Read>, TransportWriteHalf<T::Write>) {
let crypt_key = self.crypt_key;
@ -266,30 +365,7 @@ where
{
/// Sends some data across the wire, waiting for it to completely send
pub async fn send<D: Serialize>(&mut self, data: D) -> Result<(), TransportError> {
// Serialize, encrypt, and then sign
// NOTE: Cannot used packed implementation for now due to issues with deserialization
let data = serde_cbor::to_vec(&data)?;
let data = aead::seal(&self.crypt_key, &data).map_err(TransportError::EncryptError)?;
let tag = self
.auth_key
.as_ref()
.map(|key| auth::authenticate(key, &data))
.transpose()
.map_err(TransportError::AuthError)?;
// Send {TAG LEN}{TAG}{ENCRYPTED DATA} if we have an auth key,
// otherwise just send the encrypted data on its own
let mut out: Vec<u8> = Vec::new();
if let Some(tag) = tag {
let tag_len = tag.unprotected_as_bytes().len() as u8;
out.push(tag_len);
out.extend_from_slice(tag.unprotected_as_bytes());
}
out.extend(data);
self.conn.send(&out).await.map_err(TransportError::from)
send!(self.conn, self.crypt_key, self.auth_key.as_ref(), data).await
}
}
@ -315,49 +391,321 @@ where
/// Receives some data from out on the wire, waiting until it's available,
/// returning none if the transport is now closed
pub async fn receive<R: DeserializeOwned>(&mut self) -> Result<Option<R>, TransportError> {
// If data is received, we process like usual
if let Some(data) = self.conn.next().await {
let mut data = data?;
recv!(self.conn, self.crypt_key, self.auth_key).await
}
}
if data.is_empty() {
return Err(TransportError::from(io::Error::new(
io::ErrorKind::InvalidData,
"Received data is empty",
)));
#[cfg(test)]
mod tests {
use super::*;
use std::{
io,
pin::Pin,
task::{Context, Poll},
};
use tokio::{io::ReadBuf, sync::mpsc};
pub const TEST_DATA_STREAM_CHANNEL_BUFFER_SIZE: usize = 100;
/// Represents a data stream comprised of two inmemory buffers of data
pub struct TestDataStream {
incoming: TestDataStreamReadHalf,
outgoing: TestDataStreamWriteHalf,
}
impl TestDataStream {
pub fn new(incoming: mpsc::Receiver<Vec<u8>>, outgoing: mpsc::Sender<Vec<u8>>) -> Self {
Self {
incoming: TestDataStreamReadHalf(incoming),
outgoing: TestDataStreamWriteHalf(outgoing),
}
}
// Retrieve in form {TAG LEN}{TAG}{ENCRYPTED DATA}
// with the tag len and tag being optional
if let Some(auth_key) = self.auth_key.as_ref() {
// Parse the tag from the length, protecting against bad lengths
let tag_len = data[0];
if data.len() <= tag_len as usize {
return Err(TransportError::from(io::Error::new(
io::ErrorKind::InvalidData,
format!("Tag len {} > Data len {}", tag_len, data.len()),
)));
/// Returns (incoming_tx, outgoing_rx, stream)
pub fn make() -> (mpsc::Sender<Vec<u8>>, mpsc::Receiver<Vec<u8>>, Self) {
let (incoming_tx, incoming_rx) = mpsc::channel(TEST_DATA_STREAM_CHANNEL_BUFFER_SIZE);
let (outgoing_tx, outgoing_rx) = mpsc::channel(TEST_DATA_STREAM_CHANNEL_BUFFER_SIZE);
(
incoming_tx,
outgoing_rx,
Self::new(incoming_rx, outgoing_tx),
)
}
/// Returns pair of streams that are connected such that one sends to the other and
/// vice versa
pub fn pair() -> (Self, Self) {
let (tx, rx, stream) = Self::make();
(stream, Self::new(rx, tx))
}
}
impl AsyncRead for TestDataStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.incoming).poll_read(cx, buf)
}
}
impl AsyncWrite for TestDataStream {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.outgoing).poll_write(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.outgoing).poll_flush(cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.outgoing).poll_shutdown(cx)
}
}
pub struct TestDataStreamReadHalf(mpsc::Receiver<Vec<u8>>);
impl AsyncRead for TestDataStreamReadHalf {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.0.poll_recv(cx).map(|x| match x {
Some(x) => {
buf.put_slice(&x);
Ok(())
}
None => Ok(()),
})
}
}
let tag = Tag::from_slice(&data[1..=tag_len as usize])
.map_err(TransportError::AuthError)?;
// Update data with the content after the tag by mutating
// the current data to point to the return from split_off
data = data.split_off(tag_len as usize + 1);
// Validate signature, decrypt, and then deserialize
auth::authenticate_verify(&tag, auth_key, &data)
.map_err(TransportError::AuthError)?;
pub struct TestDataStreamWriteHalf(mpsc::Sender<Vec<u8>>);
impl AsyncWrite for TestDataStreamWriteHalf {
fn poll_write(
self: Pin<&mut Self>,
_: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match self.0.try_send(buf.to_vec()) {
Ok(_) => Poll::Ready(Ok(buf.len())),
Err(_) => Poll::Ready(Ok(0)),
}
}
let data = aead::open(&self.crypt_key, &data).map_err(TransportError::EncryptError)?;
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
let data = serde_cbor::from_slice(&data)?;
Ok(Some(data))
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.poll_flush(cx)
}
}
// Otherwise, if no data is received, this means that our socket has closed
} else {
Ok(None)
impl DataStream for TestDataStream {
type Read = TestDataStreamReadHalf;
type Write = TestDataStreamWriteHalf;
fn to_connection_tag(&self) -> String {
String::from("test-stream")
}
fn into_split(self) -> (Self::Read, Self::Write) {
(self.incoming, self.outgoing)
}
}
#[tokio::test]
async fn transport_from_handshake_should_fail_if_connection_reached_eof() {
// Cause nothing left incoming to stream by _
let (_, mut rx, stream) = TestDataStream::make();
let result = Transport::from_handshake(stream, None).await;
// Verify that a salt and public key were sent out first
// 1. Frame includes an 8 byte size at beginning
// 2. Salt len + 256-bit (32 byte) public key + 1 byte tag (len) for pub key
let outgoing = rx.recv().await.unwrap();
assert_eq!(
outgoing.len(),
8 + SALT_LEN + 33,
"Unexpected outgoing data: {:?}",
outgoing
);
// Then confirm that failed because didn't receive anything back
match result {
Err(x) if x.kind() == io::ErrorKind::UnexpectedEof => {}
Err(x) => panic!("Unexpected error: {:?}", x),
Ok(_) => panic!("Unexpectedly succeeded!"),
}
}
#[tokio::test]
async fn transport_from_handshake_should_fail_if_response_data_is_too_small() {
let (tx, _rx, stream) = TestDataStream::make();
// Need SALT + PUB KEY where salt has a defined size; so, at least 1 larger than salt
// would succeed, whereas we are providing exactly salt, which will fail
{
let mut frame = Vec::new();
frame.extend_from_slice(&(SALT_LEN as u64).to_be_bytes());
frame.extend_from_slice(Salt::generate(SALT_LEN).unwrap().as_ref());
tx.send(frame).await.unwrap();
drop(tx);
}
match Transport::from_handshake(stream, None).await {
Err(x) if x.kind() == io::ErrorKind::InvalidData => {}
Err(x) => panic!("Unexpected error: {:?}", x),
Ok(_) => panic!("Unexpectedly succeeded!"),
}
}
#[tokio::test]
async fn transport_from_handshake_should_fail_if_bad_foreign_public_key_received() {
let (tx, _rx, stream) = TestDataStream::make();
// Send {SALT LEN}{SALT}{PUB KEY} where public key is bad;
// normally public key bytes would be {LEN}{KEY} where len is first byte;
// if the len does not match the rest of the message len, an error will be returned
{
let mut frame = Vec::new();
frame.extend_from_slice(&((SALT_LEN + 3) as u64).to_be_bytes());
frame.extend_from_slice(Salt::generate(SALT_LEN).unwrap().as_ref());
frame.extend_from_slice(&[1, 1, 2]);
tx.send(frame).await.unwrap();
drop(tx);
}
match Transport::from_handshake(stream, None).await {
Err(x) if x.kind() == io::ErrorKind::InvalidData => {
let source = x.into_inner().expect("Inner source missing");
assert_eq!(
source.to_string(),
"crypto error",
"Unexpected source: {}",
source
);
}
Err(x) => panic!("Unexpected error: {:?}", x),
Ok(_) => panic!("Unexpectedly succeeded!"),
}
}
#[tokio::test]
async fn transport_should_be_able_to_send_encrypted_data_to_other_side_to_decrypt() {
let (src, dst) = TestDataStream::pair();
// NOTE: This is slow during tests as it is an expensive process and we're doing it twice!
let (src, dst) = tokio::join!(
Transport::from_handshake(src, None),
Transport::from_handshake(dst, None)
);
let mut src = src.expect("src stream failed handshake");
let mut dst = dst.expect("dst stream failed handshake");
src.send("some data").await.expect("Failed to send data");
let data = dst
.receive::<String>()
.await
.expect("Failed to receive data")
.expect("Data missing");
assert_eq!(data, "some data");
}
#[tokio::test]
async fn transport_should_be_able_to_sign_and_validate_signature_if_auth_key_included() {
let (src, dst) = TestDataStream::pair();
let auth_key = Arc::new(SecretKey::default());
// NOTE: This is slow during tests as it is an expensive process and we're doing it twice!
let (src, dst) = tokio::join!(
Transport::from_handshake(src, Some(Arc::clone(&auth_key))),
Transport::from_handshake(dst, Some(auth_key))
);
let mut src = src.expect("src stream failed handshake");
let mut dst = dst.expect("dst stream failed handshake");
src.send("some data").await.expect("Failed to send data");
let data = dst
.receive::<String>()
.await
.expect("Failed to receive data")
.expect("Data missing");
assert_eq!(data, "some data");
}
#[tokio::test]
async fn transport_receive_should_fail_if_auth_key_differs_from_other_end() {
let (src, dst) = TestDataStream::pair();
// Make two transports with different auth keys
// NOTE: This is slow during tests as it is an expensive process and we're doing it twice!
let (src, dst) = tokio::join!(
Transport::from_handshake(src, Some(Arc::new(SecretKey::default()))),
Transport::from_handshake(dst, Some(Arc::new(SecretKey::default())))
);
let mut src = src.expect("src stream failed handshake");
let mut dst = dst.expect("dst stream failed handshake");
src.send("some data").await.expect("Failed to send data");
match dst.receive::<String>().await {
Err(TransportError::AuthError(_)) => {}
x => panic!("Unexpected result: {:?}", x),
}
}
#[tokio::test]
async fn transport_receive_should_fail_if_has_auth_key_while_sender_did_not_use_one() {
let (src, dst) = TestDataStream::pair();
// Make two transports with different auth keys
// NOTE: This is slow during tests as it is an expensive process and we're doing it twice!
let (src, dst) = tokio::join!(
Transport::from_handshake(dst, None),
Transport::from_handshake(src, Some(Arc::new(SecretKey::default())))
);
let mut src = src.expect("src stream failed handshake");
let mut dst = dst.expect("dst stream failed handshake");
src.send("some data").await.expect("Failed to send data");
match dst.receive::<String>().await {
Err(TransportError::AuthError(_)) => {}
x => panic!("Unexpected result: {:?}", x),
}
}
#[tokio::test]
async fn transport_receive_should_fail_if_has_no_auth_key_while_sender_used_one() {
let (src, dst) = TestDataStream::pair();
// Make two transports with different auth keys
// NOTE: This is slow during tests as it is an expensive process and we're doing it twice!
let (src, dst) = tokio::join!(
Transport::from_handshake(src, Some(Arc::new(SecretKey::default()))),
Transport::from_handshake(dst, None)
);
let mut src = src.expect("src stream failed handshake");
let mut dst = dst.expect("dst stream failed handshake");
src.send("some data").await.expect("Failed to send data");
match dst.receive::<String>().await {
Err(TransportError::EncryptError(_)) => {}
x => panic!("Unexpected result: {:?}", x),
}
}
}