Fix bug where synchronization of transport didn't respect codec change

pull/146/head
Chip Senkbeil 2 years ago
parent e0c8d94592
commit 443326af62
No known key found for this signature in database
GPG Key ID: 35EF1F8EC72A4131

@ -342,8 +342,8 @@ where
debug!("[Conn {id}] Deriving future OTP for reauthentication");
let reauth_otp = transport.exchange_keys().await?.into_heap_secret_key();
// Synchronize using the provided backup
debug!("[Conn {id}] Synchronizing frame state");
// Grab the old backup and swap it into our transport
debug!("[Conn {id}] Acquiring backup for existing connection");
match x.await {
Ok(backup) => {
transport.backup = backup;
@ -352,6 +352,9 @@ where
warn!("[Conn {id}] Missing backup");
}
}
// Synchronize using the provided backup
debug!("[Conn {id}] Synchronizing frame state");
transport.synchronize().await?;
// Store the id, OTP, and backup retrieval in our database
@ -617,8 +620,8 @@ mod tests {
// Perform first step of completing client-side of handshake
t1.client_handshake().await.unwrap();
// Drop the transport to cause the server connection to fail waiting on connect type
drop(t1);
// Send some garbage that is not the connection type
t1.write_frame(Frame::new(b"hello")).await.unwrap();
// Server should fail
task.await.unwrap_err();
@ -651,6 +654,7 @@ mod tests {
t1.authenticate(DummyAuthHandler).await.unwrap_err();
// Drop the transport so we kill the server-side connection
// NOTE: If we don't drop here, the above authentication failure won't kill the server
drop(t1);
// Server should fail
@ -659,38 +663,274 @@ mod tests {
#[test(tokio::test)]
async fn server_should_fail_if_unable_to_exchange_otp_for_reauthentication_with_new_client() {
todo!();
let (mut t1, t2) = FramedTransport::pair(100);
let verifier = Verifier::none();
let keychain = Keychain::new();
// Spawn a task to perform the server connection so we don't deadlock while simulating the
// client actions on the other side
let task = tokio::spawn(async move {
Connection::server(t2.into_inner(), &verifier, keychain)
.await
.unwrap()
});
// Perform first step of completing client-side of handshake
t1.client_handshake().await.unwrap();
// Send type to indicate a new connection
t1.write_frame_for(&ConnectType::Connect).await.unwrap();
// Receive the connection id
let _id = t1.read_frame_as::<ConnectionId>().await.unwrap().unwrap();
// Pass verification using the dummy handler since our verifier supports no authentication
t1.authenticate(DummyAuthHandler).await.unwrap();
// Send some garbage to fail the exchange
t1.write_frame(Frame::new(b"hello")).await.unwrap();
// Server should fail
task.await.unwrap_err();
}
#[test(tokio::test)]
async fn server_should_fail_if_existing_client_id_is_invalid() {
todo!();
let (mut t1, t2) = FramedTransport::pair(100);
let verifier = Verifier::none();
let keychain = Keychain::new();
// Spawn a task to perform the server connection so we don't deadlock while simulating the
// client actions on the other side
let task = tokio::spawn(async move {
Connection::server(t2.into_inner(), &verifier, keychain)
.await
.unwrap()
});
// Perform first step of completing client-side of handshake
t1.client_handshake().await.unwrap();
// Send type to indicate an existing connection, which should cause the server-side to fail
// because there is no matching id
t1.write_frame_for(&ConnectType::Reconnect {
id: 1234,
otp: HeapSecretKey::generate(32)
.unwrap()
.unprotected_into_bytes(),
})
.await
.unwrap();
// Server should fail
task.await.unwrap_err();
}
#[test(tokio::test)]
async fn server_should_fail_if_existing_client_otp_is_invalid() {
todo!();
}
let (mut t1, t2) = FramedTransport::pair(100);
let verifier = Verifier::none();
let keychain = Keychain::new();
#[test(tokio::test)]
async fn server_should_fail_if_unable_to_send_id_to_existing_client() {
todo!();
keychain
.insert(
1234.to_string(),
HeapSecretKey::generate(32).unwrap(),
oneshot::channel().1,
)
.await;
// Spawn a task to perform the server connection so we don't deadlock while simulating the
// client actions on the other side
let task = tokio::spawn(async move {
Connection::server(t2.into_inner(), &verifier, keychain)
.await
.unwrap()
});
// Perform first step of completing client-side of handshake
t1.client_handshake().await.unwrap();
// Send type to indicate an existing connection, which should cause the server-side to fail
// because the OTP is wrong for the given id
t1.write_frame_for(&ConnectType::Reconnect {
id: 1234,
otp: HeapSecretKey::generate(32)
.unwrap()
.unprotected_into_bytes(),
})
.await
.unwrap();
// Server should fail
task.await.unwrap_err();
}
#[test(tokio::test)]
async fn server_should_fail_if_unable_to_exchange_otp_for_reauthentication_with_existing_client(
) {
todo!();
let (mut t1, t2) = FramedTransport::pair(100);
let verifier = Verifier::none();
let keychain = Keychain::new();
let key = HeapSecretKey::generate(32).unwrap();
keychain
.insert(1234.to_string(), key.clone(), oneshot::channel().1)
.await;
// Spawn a task to perform the server connection so we don't deadlock while simulating the
// client actions on the other side
let task = tokio::spawn(async move {
Connection::server(t2.into_inner(), &verifier, keychain)
.await
.unwrap()
});
// Perform first step of completing client-side of handshake
t1.client_handshake().await.unwrap();
// Send type to indicate an existing connection, which should cause the server-side to fail
// because the OTP is wrong for the given id
t1.write_frame_for(&ConnectType::Reconnect {
id: 1234,
otp: key.unprotected_into_bytes(),
})
.await
.unwrap();
// Receive a new client id
let _id = t1.read_frame_as::<ConnectionId>().await.unwrap().unwrap();
// Send garbage to fail the otp exchange
t1.write_frame(Frame::new(b"hello")).await.unwrap();
// Server should fail
task.await.unwrap_err();
}
#[test(tokio::test)]
async fn server_should_fail_if_unable_to_synchronize_with_existing_client() {
todo!();
let (mut t1, t2) = FramedTransport::pair(100);
let verifier = Verifier::none();
let keychain = Keychain::new();
let key = HeapSecretKey::generate(32).unwrap();
keychain
.insert(1234.to_string(), key.clone(), oneshot::channel().1)
.await;
// Spawn a task to perform the server connection so we don't deadlock while simulating the
// client actions on the other side
let task = tokio::spawn(async move {
Connection::server(t2.into_inner(), &verifier, keychain)
.await
.unwrap()
});
// Perform first step of completing client-side of handshake
t1.client_handshake().await.unwrap();
// Send type to indicate an existing connection, which should cause the server-side to fail
// because the OTP is wrong for the given id
t1.write_frame_for(&ConnectType::Reconnect {
id: 1234,
otp: key.unprotected_into_bytes(),
})
.await
.unwrap();
// Receive a new client id
let _id = t1.read_frame_as::<ConnectionId>().await.unwrap().unwrap();
// Perform otp exchange
let _otp = t1.exchange_keys().await.unwrap();
// Send garbage to fail synchronization
t1.write_frame(b"hello").await.unwrap();
// Server should fail
task.await.unwrap_err();
}
#[test(tokio::test)]
async fn server_should_succeed_if_establishes_connection_with_new_client() {
todo!();
let (mut t1, t2) = FramedTransport::pair(100);
let verifier = Verifier::none();
let keychain = Keychain::new();
let key = HeapSecretKey::generate(32).unwrap();
keychain
.insert(1234.to_string(), key.clone(), {
// Create a custom backup we'll use to replay frames from the server-side
let mut backup = Backup::new();
backup.push_frame(Frame::new(b"hello"));
backup.push_frame(Frame::new(b"world"));
backup.increment_sent_cnt();
backup.increment_sent_cnt();
let (tx, rx) = oneshot::channel();
tx.send(backup).unwrap();
rx
})
.await;
// Spawn a task to perform the server connection so we don't deadlock while simulating the
// client actions on the other side
let task = tokio::spawn({
let keychain = keychain.clone();
async move {
Connection::server(t2.into_inner(), &verifier, keychain)
.await
.unwrap()
}
});
// Perform first step of completing client-side of handshake
t1.client_handshake().await.unwrap();
// Send type to indicate an existing connection, which should cause the server-side to fail
// because the OTP is wrong for the given id
t1.write_frame_for(&ConnectType::Reconnect {
id: 1234,
otp: key.unprotected_into_bytes(),
})
.await
.unwrap();
// Receive a new client id
let id = t1.read_frame_as::<ConnectionId>().await.unwrap().unwrap();
// Perform otp exchange
let otp = t1.exchange_keys().await.unwrap();
// Queue up some frames to send to the server
t1.backup.clear();
t1.backup.push_frame(Frame::new(b"foo"));
t1.backup.push_frame(Frame::new(b"bar"));
t1.backup.increment_sent_cnt();
t1.backup.increment_sent_cnt();
// Perform synchronization
t1.synchronize().await.unwrap();
// Verify that we received frames from the server
assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"hello");
assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"world");
// Server connection should be established, and have received some replayed frames
let mut server = task.await.unwrap();
assert_eq!(server.read_frame().await.unwrap().unwrap(), b"foo");
assert_eq!(server.read_frame().await.unwrap().unwrap(), b"bar");
// Validate the connection ids match
assert_eq!(server.id(), id);
// Validate the OTP was stored in our keychain
assert!(
keychain.has_key("1234", otp.into_heap_secret_key()).await,
"Missing OTP"
);
}
#[test(tokio::test)]

@ -323,18 +323,21 @@ impl<T: Transport> FramedTransport<T> {
F: TryInto<Frame<'a>>,
F::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
// Grab the frame to send
let frame = frame
.try_into()
.map_err(|x| io::Error::new(io::ErrorKind::InvalidInput, x))?;
// Encode the frame and store it in our outgoing queue
let frame = self.codec.encode(
frame
.try_into()
.map_err(|x| io::Error::new(io::ErrorKind::InvalidInput, x))?,
)?;
frame.write(&mut self.outgoing)?;
self.codec
.encode(frame.as_borrowed())?
.write(&mut self.outgoing)?;
// Once the frame enters our queue, we count it as written, even if it isn't fully flushed
self.backup.increment_sent_cnt();
// Then we store the frame for the future in case we need to retry sending it later
// Then we store the raw frame (non-encoded) for the future in case we need to retry
// sending it later (possibly with a different codec)
self.backup.push_frame(frame);
// Attempt to write everything in our queue
@ -476,7 +479,9 @@ impl<T: Transport> FramedTransport<T> {
backup.truncate_front(resend_cnt.try_into().expect("Cannot cast usize to u64"));
debug!("Sending {resend_cnt} frames");
backup.write(&mut this.outgoing)?;
for frame in backup.frames() {
this.try_write_frame(frame.as_borrowed())?;
}
this.flush().await?;
// Receive all expected frames, placing their contents into our incoming queue
@ -495,7 +500,10 @@ impl<T: Transport> FramedTransport<T> {
),
)
})?;
frame.write(&mut this.incoming)?;
// Encode our frame and write it to be queued in our incoming data
// NOTE: We have to do encoding here as incoming bytes are expected to be encoded
this.codec.encode(frame)?.write(&mut this.incoming)?;
}
// Catch up our read count as we can have the case where the other side has a higher
@ -815,10 +823,10 @@ where
}
impl FramedTransport<InmemoryTransport> {
/// Produces a pair of inmemory transports that are connected to each other using
/// a standard codec.
/// Produces a pair of inmemory transports that are connected to each other using a
/// [`PlainCodec`].
///
/// Sets the buffer for message passing for each underlying transport to the given buffer size
/// Sets the buffer for message passing for each underlying transport to the given buffer size.
pub fn pair(
buffer: usize,
) -> (
@ -830,6 +838,11 @@ impl FramedTransport<InmemoryTransport> {
let b = FramedTransport::new(b, Box::new(PlainCodec::new()));
(a, b)
}
/// Links the underlying transports together using [`InmemoryTransport::link`].
pub fn link(&mut self, other: &mut Self, buffer: usize) {
self.inner.link(&mut other.inner, buffer)
}
}
#[cfg(test)]
@ -1811,6 +1824,41 @@ mod tests {
assert_eq!(t2.backup.frame_cnt(), 0, "Wrong frame cnt");
}
#[test(tokio::test)]
async fn synchronize_should_work_even_if_codec_changes_between_attempts() {
let (mut t1, _t1_other) = FramedTransport::pair(100);
let (mut t2, _t2_other) = FramedTransport::pair(100);
// Send some frames from each side
t1.write_frame(Frame::new(b"hello")).await.unwrap();
t1.write_frame(Frame::new(b"world")).await.unwrap();
t2.write_frame(Frame::new(b"foo")).await.unwrap();
t2.write_frame(Frame::new(b"bar")).await.unwrap();
// Drop the other transports, link our real transports together, and change the codec
drop(_t1_other);
drop(_t2_other);
t1.link(&mut t2, 100);
let codec = EncryptionCodec::new_xchacha20poly1305(Default::default());
t1.codec = Box::new(codec.clone());
t2.codec = Box::new(codec);
// Spawn a separate task to do synchronization so we don't deadlock
let task = tokio::spawn(async move {
t2.synchronize().await.unwrap();
t2
});
t1.synchronize().await.unwrap();
// Verify that we get the appropriate frames from both sides
let mut t2 = task.await.unwrap();
assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"foo");
assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"bar");
assert_eq!(t2.read_frame().await.unwrap().unwrap(), b"hello");
assert_eq!(t2.read_frame().await.unwrap().unwrap(), b"world");
}
#[test(tokio::test)]
async fn handshake_should_configure_transports_with_matching_codec() {
let (mut t1, mut t2) = FramedTransport::test_pair(100);

@ -1,7 +1,5 @@
use super::{Frame, OwnedFrame};
use bytes::BytesMut;
use std::collections::VecDeque;
use std::io;
/// Maximum size (in bytes) for saved frames (256MiB)
const MAX_BACKUP_SIZE: usize = 256 * 1024 * 1024;
@ -110,7 +108,7 @@ impl Backup {
/// ### Note
///
/// Like all other modifications, this will do nothing if the backup is frozen.
pub(super) fn increment_sent_cnt(&mut self) {
pub(crate) fn increment_sent_cnt(&mut self) {
if !self.frozen {
self.sent_cnt += 1;
}
@ -153,7 +151,7 @@ impl Backup {
/// ### Note
///
/// Like all other modifications, this will do nothing if the backup is frozen.
pub(super) fn push_frame(&mut self, frame: Frame) {
pub(crate) fn push_frame(&mut self, frame: Frame) {
if self.max_backup_size > 0 && !self.frozen {
self.current_backup_size += frame.len();
self.frames.push_back(frame.into_owned());
@ -179,15 +177,9 @@ impl Backup {
self.frames.len()
}
/// Writes all stored frames to the `dst` by invoking [`Frame::write`] in sequence.
///
/// [`Frame::write`]: super::Frame::write
pub(super) fn write(&self, dst: &mut BytesMut) -> io::Result<()> {
for frame in self.frames.iter() {
frame.write(dst)?;
}
Ok(())
/// Returns an iterator over the frames contained in the backup.
pub(super) fn frames(&self) -> impl Iterator<Item = &Frame> {
self.frames.iter()
}
/// Truncates the stored frames to be no larger than `size` total frames by popping from the

@ -57,6 +57,25 @@ impl InmemoryTransport {
(transport, Self::new(tx, rx))
}
/// Links two independent [`InmemoryTransport`] together by dropping their internal channels
/// and generating new ones of `buffer` capacity to connect these transports.
///
/// ### Note
///
/// This will drop any pre-existing data in the internal storage to avoid corruption.
pub fn link(&mut self, other: &mut InmemoryTransport, buffer: usize) {
let (incoming_tx, incoming_rx) = mpsc::channel(buffer);
let (outgoing_tx, outgoing_rx) = mpsc::channel(buffer);
self.buf = Mutex::new(None);
self.tx = outgoing_tx;
self.rx = Mutex::new(incoming_rx);
other.buf = Mutex::new(None);
other.tx = incoming_tx;
other.rx = Mutex::new(outgoing_rx);
}
/// Returns true if the read channel is closed, meaning it will no longer receive more data.
/// This does not factor in data remaining in the internal buffer, meaning that this may return
/// true while the transport still has data remaining in the internal buffer.

Loading…
Cancel
Save