Progress towards stateful framed transport

pull/146/head
Chip Senkbeil 2 years ago
parent a2ec96e556
commit e95370589f
No known key found for this signature in database
GPG Key ID: 35EF1F8EC72A4131

@ -2,7 +2,6 @@ mod any;
// mod auth;
mod client;
mod id;
mod key;
mod listener;
mod packet;
mod port;
@ -14,7 +13,6 @@ pub use any::*;
// pub use auth::*;
pub use client::*;
pub use id::*;
pub use key::*;
pub use listener::*;
pub use packet::*;
pub use port::*;

@ -12,6 +12,9 @@ pub use inmemory::*;
mod tcp;
pub use tcp::*;
mod stateful;
pub use stateful::*;
#[cfg(test)]
mod test;

@ -1,7 +1,7 @@
use super::{Interest, Ready, Reconnectable, Transport};
use async_trait::async_trait;
use bytes::{Buf, BytesMut};
use std::{fmt, io, sync::Arc};
use std::{fmt, io};
mod codec;
pub use codec::*;
@ -9,9 +9,6 @@ pub use codec::*;
mod frame;
pub use frame::*;
mod handshake;
pub use handshake::*;
/// By default, framed transport's initial capacity (and max single-read) will be 8 KiB
const DEFAULT_CAPACITY: usize = 8 * 1024;
@ -24,42 +21,23 @@ const DEFAULT_CAPACITY: usize = 8 * 1024;
pub struct FramedTransport<T, const CAPACITY: usize = DEFAULT_CAPACITY> {
inner: T,
codec: BoxedCodec,
handshake: Handshake,
incoming: BytesMut,
outgoing: BytesMut,
}
impl<T, const CAPACITY: usize> FramedTransport<T, CAPACITY> {
fn new(inner: T, codec: BoxedCodec, handshake: Handshake) -> Self {
pub fn new(inner: T, codec: BoxedCodec) -> Self {
Self {
inner,
codec,
handshake,
incoming: BytesMut::with_capacity(CAPACITY),
outgoing: BytesMut::with_capacity(CAPACITY),
}
}
/// Creates a new [`FramedTransport`] using the [`PlainCodec`]
fn plain(inner: T, handshake: Handshake) -> Self {
Self::new(inner, Box::new(PlainCodec::new()), handshake)
}
/// Performs a handshake with the other side of the `transport` in order to determine which
/// [`Codec`] to use as well as perform any additional logic to prepare the framed transport.
///
/// Will use the handshake criteria provided in `handshake`
pub async fn from_handshake(
transport: T,
handshake: Handshake,
) -> io::Result<FramedTransport<T, CAPACITY>>
where
T: Transport,
{
let mut transport = Self::plain(transport, handshake);
handshake::do_handshake(&mut transport).await?;
Ok(transport)
pub fn plain(inner: T) -> Self {
Self::new(inner, Box::new(PlainCodec::new()))
}
/// Replaces the current codec with the provided codec. Note that any bytes in the incoming or
@ -254,14 +232,7 @@ where
T: Transport + Send + Sync,
{
async fn reconnect(&mut self) -> io::Result<()> {
// Establish a new connection
Reconnectable::reconnect(&mut self.inner).await?;
// Perform handshake again, which can result in the underlying codec
// changing based on the exchange; so, we want to clear out any lingering
// bytes in the incoming and outgoing queues
self.clear();
handshake::do_handshake(self).await
Reconnectable::reconnect(&mut self.inner).await
}
}
@ -277,25 +248,8 @@ impl<const CAPACITY: usize> FramedTransport<super::InmemoryTransport, CAPACITY>
FramedTransport<super::InmemoryTransport, CAPACITY>,
) {
let (a, b) = super::InmemoryTransport::pair(buffer);
let a = FramedTransport::new(
a,
Box::new(PlainCodec::new()),
Handshake::Client {
key: HeapSecretKey::from(Vec::new()),
preferred_compression_type: None,
preferred_compression_level: None,
preferred_encryption_type: None,
},
);
let b = FramedTransport::new(
b,
Box::new(PlainCodec::new()),
Handshake::Server {
key: HeapSecretKey::from(Vec::new()),
compression_types: Vec::new(),
encryption_types: Vec::new(),
},
);
let a = FramedTransport::new(a, Box::new(PlainCodec::new()));
let b = FramedTransport::new(b, Box::new(PlainCodec::new()));
(a, b)
}
}
@ -385,7 +339,7 @@ mod tests {
f_ready: Box::new(|_| Ok(Ready::READABLE)),
..Default::default()
},
OkCodec,
Box::new(OkCodec),
);
assert_eq!(
transport.try_read_frame().unwrap_err().kind(),
@ -399,7 +353,7 @@ mod tests {
f_ready: Box::new(|_| Ok(Ready::READABLE)),
..Default::default()
},
OkCodec,
Box::new(OkCodec),
);
assert_eq!(
transport.try_read_frame().unwrap_err().kind(),
@ -415,7 +369,7 @@ mod tests {
f_ready: Box::new(|_| Ok(Ready::READABLE)),
..Default::default()
},
OkCodec,
Box::new(OkCodec),
);
assert_eq!(
transport.try_read_frame().unwrap_err().kind(),
@ -431,7 +385,7 @@ mod tests {
f_ready: Box::new(|_| Ok(Ready::READABLE)),
..Default::default()
},
ErrCodec,
Box::new(ErrCodec),
);
assert_eq!(
transport.try_read_frame().unwrap_err().kind(),
@ -456,7 +410,7 @@ mod tests {
f_ready: Box::new(|_| Ok(Ready::READABLE)),
..Default::default()
},
OkCodec,
Box::new(OkCodec),
);
assert_eq!(transport.try_read_frame().unwrap().unwrap(), b"hello world");
}
@ -475,7 +429,7 @@ mod tests {
f_ready: Box::new(|_| Ok(Ready::READABLE)),
..Default::default()
},
OkCodec,
Box::new(OkCodec),
);
assert_eq!(transport.try_read_frame().unwrap().unwrap(), b"hello world");
@ -495,7 +449,7 @@ mod tests {
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
..Default::default()
},
OkCodec,
Box::new(OkCodec),
);
// First call will only write part of the frame and then return WouldBlock
@ -516,7 +470,7 @@ mod tests {
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
..Default::default()
},
OkCodec,
Box::new(OkCodec),
);
assert_eq!(
transport
@ -535,7 +489,7 @@ mod tests {
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
..Default::default()
},
ErrCodec,
Box::new(ErrCodec),
);
assert_eq!(
transport
@ -559,7 +513,7 @@ mod tests {
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
..Default::default()
},
OkCodec,
Box::new(OkCodec),
);
transport.try_write_frame(b"hello world").unwrap();
@ -593,7 +547,7 @@ mod tests {
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
..Default::default()
},
OkCodec,
Box::new(OkCodec),
);
// First call will only write part of the frame and then return WouldBlock
@ -636,7 +590,7 @@ mod tests {
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
..Default::default()
},
OkCodec,
Box::new(OkCodec),
);
// Set our outgoing buffer to flush
@ -657,7 +611,7 @@ mod tests {
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
..Default::default()
},
OkCodec,
Box::new(OkCodec),
);
// Set our outgoing buffer to flush
@ -678,7 +632,7 @@ mod tests {
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
..Default::default()
},
OkCodec,
Box::new(OkCodec),
);
// Perform flush and verify nothing happens
@ -699,7 +653,7 @@ mod tests {
f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
..Default::default()
},
OkCodec,
Box::new(OkCodec),
);
// Set our outgoing buffer to flush

@ -1,44 +0,0 @@
use super::{HandshakeClientChoice, HandshakeServerOptions};
use std::fmt;
/// Callback invoked when a client receives server options during a handshake
pub struct OnHandshakeClientChoice(
pub(super) Box<dyn Fn(HandshakeServerOptions) -> HandshakeClientChoice>,
);
impl OnHandshakeClientChoice {
/// Wraps a function `f` as a callback
pub fn new<F>(f: F) -> Self
where
F: Fn(HandshakeServerOptions) -> HandshakeClientChoice,
{
Self(Box::new(f))
}
}
impl<F> From<F> for OnHandshakeClientChoice
where
F: Fn(HandshakeServerOptions) -> HandshakeClientChoice,
{
fn from(f: F) -> Self {
Self::new(f)
}
}
impl fmt::Debug for OnHandshakeClientChoice {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("OnHandshakeClientChoice").finish()
}
}
impl Default for OnHandshakeClientChoice {
/// Implements choice selection that picks first available of encryption and nothing of
/// compression
fn default() -> Self {
Self::new(|options| HandshakeClientChoice {
compression: None,
compression_level: None,
encryption: options.encryption.first().copied(),
})
}
}

@ -1,44 +0,0 @@
use super::FramedTransport;
use std::{fmt, future::Future, io, pin::Pin};
/// Boxed function representing `on_handshake` callback
pub type BoxedOnHandshakeFn<T, const CAPACITY: usize> = Box<
dyn FnMut(&mut FramedTransport<T, CAPACITY>) -> Pin<Box<dyn Future<Output = io::Result<()>>>>,
>;
/// Callback invoked when a handshake occurs
pub struct OnHandshake<T, const CAPACITY: usize>(pub(super) BoxedOnHandshakeFn<T, CAPACITY>);
impl<T, const CAPACITY: usize> OnHandshake<T, CAPACITY> {
/// Wraps a function `f` as a callback for a handshake
pub fn new<F>(f: F) -> Self
where
F: FnMut(
&mut FramedTransport<T, CAPACITY>,
) -> Pin<Box<dyn Future<Output = io::Result<()>>>>,
{
Self(Box::new(f))
}
}
impl<T, F, const CAPACITY: usize> From<F> for OnHandshake<T, CAPACITY>
where
F: FnMut(&mut FramedTransport<T, CAPACITY>) -> Pin<Box<dyn Future<Output = io::Result<()>>>>,
{
fn from(f: F) -> Self {
Self::new(f)
}
}
impl<T, const CAPACITY: usize> fmt::Debug for OnHandshake<T, CAPACITY> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("OnHandshake").finish()
}
}
impl<T, const CAPACITY: usize> Default for OnHandshake<T, CAPACITY> {
/// Implements handshake callback that does nothing
fn default() -> Self {
Self::new(|_| Box::pin(async { Ok(()) }))
}
}

@ -0,0 +1,81 @@
use super::{FramedTransport, HeapSecretKey, Reconnectable, Transport};
use async_trait::async_trait;
use std::io;
mod handshake;
pub use handshake::*;
#[derive(Clone, Debug)]
enum State {
NotAuthenticated,
Authenticated {
key: HeapSecretKey,
handshake_options: HandshakeOptions,
},
}
/// Represents an stateful framed transport that is capable of peforming handshakes and
/// reconnecting using an authenticated state
#[derive(Clone, Debug)]
pub struct StatefulFramedTransport<T, const CAPACITY: usize> {
inner: FramedTransport<T, CAPACITY>,
state: State,
}
impl<T, const CAPACITY: usize> StatefulFramedTransport<T, CAPACITY> {
/// Creates a new stateful framed transport that is not yet authenticated
pub fn new(inner: FramedTransport<T, CAPACITY>) -> Self {
Self {
inner,
state: State::NotAuthenticated,
}
}
/// Performs an authentication handshake, moving the state to be authenticated.
///
/// Does nothing if already authenticated
pub async fn authenticate(&mut self, handshake_options: HandshakeOptions) -> io::Result<()> {
if self.is_authenticated() {
return Ok(());
}
todo!();
}
/// Returns true if in an authenticated state
pub fn is_authenticated(&self) -> bool {
matches!(self.state, State::Authenticated { .. })
}
/// Returns a reference to the [`HandshakeOptions`] used during authentication. Returns `None`
/// if not authenticated.
pub fn handshake_options(&self) -> Option<&HandshakeOptions> {
match &self.state {
State::NotAuthenticated => None,
State::Authenticated {
handshake_options, ..
} => Some(handshake_options),
}
}
}
#[async_trait]
impl<T, const CAPACITY: usize> Reconnectable for StatefulFramedTransport<T, CAPACITY>
where
T: Transport + Send + Sync,
{
async fn reconnect(&mut self) -> io::Result<()> {
match self.state {
// If not authenticated, we simply perform a raw reconnect
State::NotAuthenticated => Reconnectable::reconnect(&mut self.inner).await,
// If authenticated, we perform a reconnect followed by re-authentication using our
// previously-derived key to skip the need to do another authentication
State::Authenticated { key, .. } => {
Reconnectable::reconnect(&mut self.inner).await?;
todo!("do handshake with key");
}
}
}
}

@ -7,12 +7,6 @@ use log::*;
use serde::{Deserialize, Serialize};
use std::io;
mod on_choice;
mod on_handshake;
pub use on_choice::*;
pub use on_handshake::*;
/// Options from the server representing available methods to configure a framed transport
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct HandshakeServerOptions {
@ -35,12 +29,9 @@ pub struct HandshakeClientChoice {
/// Definition of the handshake to perform for a transport
#[derive(Clone, Debug)]
pub enum Handshake {
pub enum HandshakeOptions {
/// Indicates that the handshake is being performed from the client-side
Client {
/// Secret key to use with encryption
key: HeapSecretKey,
/// Preferred compression algorithm when presented options by server
preferred_compression_type: Option<CompressionType>,
@ -53,9 +44,6 @@ pub enum Handshake {
/// Indicates that the handshake is being performed from the server-side
Server {
/// Secret key to use with encryption
key: HeapSecretKey,
/// List of available compression algorithms for use between client and server
compression_types: Vec<CompressionType>,
@ -64,25 +52,23 @@ pub enum Handshake {
},
}
impl Handshake {
/// Creates a new client handshake definition, using `key` for encryption, providing defaults
/// for the preferred compression type, compression level, and encryption type
pub fn client(key: HeapSecretKey) -> Self {
impl HandshakeOptions {
/// Creates a new client handshake definition, providing defaults for the preferred compression
/// type, compression level, and encryption type
pub fn client() -> Self {
Self::Client {
key,
preferred_compression_type: None,
preferred_compression_level: None,
preferred_encryption_type: Some(EncryptionType::XChaCha20Poly1305),
}
}
/// Creates a new client handshake definition, using `key` for encryption, providing defaults
/// for the compression types and encryption types by including all known variants
pub fn server(key: HeapSecretKey) -> Self {
/// Creates a new server handshake definition, providing defaults for the compression types and
/// encryption types by including all known variants
pub fn server() -> Self {
Self::Server {
compression_types: CompressionType::known_variants().to_vec(),
encryption_types: EncryptionType::known_variants().to_vec(),
key,
}
}
}
@ -133,8 +119,8 @@ where
}
match transport.handshake.clone() {
Handshake::Client {
key,
HandshakeOptions::Client {
access_token,
preferred_compression_type,
preferred_compression_level,
preferred_encryption_type,
@ -159,9 +145,9 @@ where
// Transform the transport's codec to abide by the choice
debug!("[Handshake] Client updating codec based on {choice:#?}");
transform_transport(transport, choice, &key)
transform_transport(transport, choice, &access_token)
}
Handshake::Server {
HandshakeOptions::Server {
key,
compression_types,
encryption_types,

@ -1,106 +0,0 @@
use super::{Interest, Ready, Reconnectable, TypedTransport};
use async_trait::async_trait;
use std::{io, sync::Mutex};
use tokio::sync::mpsc::{
self,
error::{TryRecvError, TrySendError},
};
/// Represents a [`TypedTransport`] of data across the network that uses tokio's mpsc [`Sender`]
/// and [`Receiver`] underneath.
///
/// [`Sender`]: mpsc::Sender
/// [`Receiver`]: mpsc::Receiver
#[derive(Debug)]
pub struct InmemoryTypedTransport<T, U> {
tx: mpsc::Sender<T>,
rx: Mutex<mpsc::Receiver<U>>,
}
impl<T, U> InmemoryTypedTransport<T, U> {
pub fn new(tx: mpsc::Sender<T>, rx: mpsc::Receiver<U>) -> Self {
Self {
tx,
rx: Mutex::new(rx),
}
}
/// Creates a pair of connected transports using `buffer` as maximum
/// channel capacity for each
pub fn pair(buffer: usize) -> (InmemoryTypedTransport<T, U>, InmemoryTypedTransport<U, T>) {
let (t_tx, t_rx) = mpsc::channel(buffer);
let (u_tx, u_rx) = mpsc::channel(buffer);
(
InmemoryTypedTransport::new(t_tx, u_rx),
InmemoryTypedTransport::new(u_tx, t_rx),
)
}
}
#[async_trait]
impl<T, U> Reconnectable for InmemoryTypedTransport<T, U>
where
T: Send,
U: Send,
{
/// Once the underlying channels have closed, there is no way for this transport to
/// re-establish those channels; therefore, reconnecting will always fail with
/// [`ErrorKind::Unsupported`]
///
/// [`ErrorKind::Unsupported`]: io::ErrorKind::Unsupported
async fn reconnect(&mut self) -> io::Result<()> {
Err(io::Error::from(io::ErrorKind::Unsupported))
}
}
#[async_trait]
impl<T, U> TypedTransport for InmemoryTypedTransport<T, U>
where
T: Send,
U: Send,
{
type Input = U;
type Output = T;
fn try_read(&self) -> io::Result<Option<Self::Input>> {
match self.rx.lock().unwrap().try_recv() {
Ok(x) => Ok(Some(x)),
Err(TryRecvError::Empty) => Err(io::Error::from(io::ErrorKind::WouldBlock)),
Err(TryRecvError::Disconnected) => Ok(None),
}
}
fn try_write(&self, value: Self::Output) -> io::Result<()> {
match self.tx.try_send(value) {
Ok(()) => Ok(()),
Err(TrySendError::Full(_)) => Err(io::Error::from(io::ErrorKind::WouldBlock)),
Err(TrySendError::Closed(_)) => Err(io::Error::from(io::ErrorKind::BrokenPipe)),
}
}
async fn ready(&self, interest: Interest) -> io::Result<Ready> {
let mut status = Ready::EMPTY;
if interest.is_readable() {
// TODO: Replace `self.is_rx_closed()` with `self.rx.is_closed()` once the tokio issue
// is resolved that adds `is_closed` to the `mpsc::Receiver`
//
// See https://github.com/tokio-rs/tokio/issues/4638
status |= if self.is_rx_closed() && self.buf.lock().unwrap().is_none() {
Ready::READ_CLOSED
} else {
Ready::READABLE
};
}
if interest.is_writable() {
status |= if self.tx.is_closed() {
Ready::WRITE_CLOSED
} else {
Ready::WRITABLE
};
}
Ok(status)
}
}
Loading…
Cancel
Save