Refactor exchange key functionality into public method on FramedTransport

pull/146/head
Chip Senkbeil 2 years ago
parent 9464b9e4ad
commit 5a4aceccb7
No known key found for this signature in database
GPG Key ID: 35EF1F8EC72A4131

@ -44,9 +44,12 @@ pub trait Authenticator: Send {
macro_rules! write_frame {
($transport:expr, $data:expr) => {{
$transport
.write_frame(utils::serialize_to_vec(&$data)?)
.await?
let data = utils::serialize_to_vec(&$data)?;
if log_enabled!(Level::Trace) {
trace!("Writing data as frame: {data:?}");
}
$transport.write_frame(data).await?
}};
}
@ -67,7 +70,21 @@ macro_rules! next_frame_as {
io::Error::new(io::ErrorKind::UnexpectedEof, "Transport closed early")
})?;
utils::deserialize_from_slice::<$type>(frame.as_item())?
match utils::deserialize_from_slice::<$type>(frame.as_item()) {
Ok(frame) => frame,
Err(x) => {
if log_enabled!(Level::Trace) {
trace!(
"Failed to deserialize frame item as {}: {:?}",
stringify!($type),
frame.as_item()
);
}
Err(x)?;
unreachable!();
}
}
}};
}

@ -214,10 +214,13 @@ where
let (tx, mut rx) = mpsc::channel::<Response<H::Response>>(1);
// Perform a handshake to ensure that the connection is properly established
let mut transport: FramedTransport<T> = FramedTransport::plain(transport);
if let Err(x) = transport.server_handshake().await {
terminate_connection!(@error "[Conn {id}] Handshake failed: {x}");
}
let mut transport: FramedTransport<T> =
match FramedTransport::from_server_handshake(transport).await {
Ok(x) => x,
Err(x) => {
terminate_connection!(@error "[Conn {id}] Handshake failed: {x}");
}
};
// Perform authentication to ensure the connection is valid
match Weak::upgrade(&verifier) {

@ -304,6 +304,16 @@ impl<T: Transport> FramedTransport<T> {
pub async fn client_handshake(&mut self) -> io::Result<()> {
self.handshake(Handshake::client()).await
}
/// Shorthand for creating a [`FramedTransport`] with a [`PlainCodec`] and then immediately
/// performing a [`server_handshake`], returning the updated [`FramedTransport`] on success.
///
/// [`client_handshake`]: FramedTransport::client_handshake
#[inline]
pub async fn from_server_handshake(transport: T) -> io::Result<Self> {
let mut transport = Self::plain(transport);
transport.server_handshake().await?;
Ok(transport)
}
/// Perform the server-side of a handshake. See [`handshake`] for more details.
///
@ -488,33 +498,7 @@ impl<T: Transport> FramedTransport<T> {
))
}
Some(ty) => {
#[derive(Serialize, Deserialize)]
struct KeyExchangeData {
/// Bytes of the public key
#[serde(with = "serde_bytes")]
public_key: PublicKeyBytes,
/// Randomly generated salt
#[serde(with = "serde_bytes")]
salt: Salt,
}
debug!("[{log_label}] Exchanging public key and salt");
let exchange = KeyExchange::default();
write_frame!(KeyExchangeData {
public_key: exchange.pk_bytes(),
salt: *exchange.salt(),
});
// TODO: This key only works because it happens to be 32 bytes and our encryption
// also wants a 32-byte key. Once we introduce new encryption algorithms that
// are not using 32-byte keys, the key exchange will need to support deriving
// other length keys.
trace!("[{log_label}] Waiting on public key and salt from other side");
let data = next_frame_as!(KeyExchangeData);
trace!("[{log_label}] Deriving shared secret key");
let key = exchange.derive_shared_secret(data.public_key, data.salt)?;
let key = self.exchange_keys_impl(log_label).await?;
Some(ty.new_codec(key.unprotected_as_bytes())?)
}
None => None,
@ -539,6 +523,65 @@ impl<T: Transport> FramedTransport<T> {
Ok(codec)
}
/// Places the transport into key-exchange mode where it attempts to derive a shared secret key
/// with the other transport.
pub async fn exchange_keys(&mut self) -> io::Result<SecretKey32> {
self.exchange_keys_impl("").await
}
async fn exchange_keys_impl(&mut self, label: &str) -> io::Result<SecretKey32> {
let log_label = if label.is_empty() {
String::new()
} else {
format!("[{label}] ")
};
macro_rules! write_frame {
($data:expr) => {{
self.write_frame(utils::serialize_to_vec(&$data)?).await?
}};
}
macro_rules! next_frame_as {
($type:ty) => {{
let frame = self.read_frame().await?.ok_or_else(|| {
io::Error::new(io::ErrorKind::UnexpectedEof, "Transport closed early")
})?;
utils::deserialize_from_slice::<$type>(frame.as_item())?
}};
}
#[derive(Serialize, Deserialize)]
struct KeyExchangeData {
/// Bytes of the public key
#[serde(with = "serde_bytes")]
public_key: PublicKeyBytes,
/// Randomly generated salt
#[serde(with = "serde_bytes")]
salt: Salt,
}
debug!("{log_label}Exchanging public key and salt");
let exchange = KeyExchange::default();
write_frame!(KeyExchangeData {
public_key: exchange.pk_bytes(),
salt: *exchange.salt(),
});
// TODO: This key only works because it happens to be 32 bytes and our encryption
// also wants a 32-byte key. Once we introduce new encryption algorithms that
// are not using 32-byte keys, the key exchange will need to support deriving
// other length keys.
trace!("{log_label}Waiting on public key and salt from other side");
let data = next_frame_as!(KeyExchangeData);
trace!("{log_label}Deriving shared secret key");
let key = exchange.derive_shared_secret(data.public_key, data.salt)?;
Ok(key)
}
}
#[async_trait]
@ -1361,4 +1404,46 @@ mod tests {
let err = t1.server_handshake().await.unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
}
#[test(tokio::test)]
async fn exchange_keys_should_fail_if_unable_to_send_exchange_data_to_other_side() {
let (mut t1, t2) = FramedTransport::test_pair(100);
// Drop the other side to ensure that the exchange fails at the beginning
drop(t2);
// Perform key exchange and verify error is as expected
assert_eq!(
t1.exchange_keys().await.unwrap_err().kind(),
io::ErrorKind::WriteZero
);
}
#[test(tokio::test)]
async fn exchange_keys_should_fail_if_received_invalid_exchange_data() {
let (mut t1, mut t2) = FramedTransport::test_pair(100);
// Queue up an invalid exchange response
t2.write_frame(b"some invalid frame").await.unwrap();
// Perform key exchange and verify error is as expected
assert_eq!(
t1.exchange_keys().await.unwrap_err().kind(),
io::ErrorKind::InvalidData
);
}
#[test(tokio::test)]
async fn exchange_keys_should_return_shared_secret_key_if_successful() {
let (mut t1, mut t2) = FramedTransport::test_pair(100);
// Spawn a task to avoid deadlocking
let task = tokio::spawn(async move { t2.exchange_keys().await.unwrap() });
// Perform key exchange
let key = t1.exchange_keys().await.unwrap();
// Validate that the keys on both sides match
assert_eq!(key, task.await.unwrap());
}
}

Loading…
Cancel
Save