|
|
|
@ -304,6 +304,16 @@ impl<T: Transport> FramedTransport<T> {
|
|
|
|
|
pub async fn client_handshake(&mut self) -> io::Result<()> {
|
|
|
|
|
self.handshake(Handshake::client()).await
|
|
|
|
|
}
|
|
|
|
|
/// Shorthand for creating a [`FramedTransport`] with a [`PlainCodec`] and then immediately
|
|
|
|
|
/// performing a [`server_handshake`], returning the updated [`FramedTransport`] on success.
|
|
|
|
|
///
|
|
|
|
|
/// [`client_handshake`]: FramedTransport::client_handshake
|
|
|
|
|
#[inline]
|
|
|
|
|
pub async fn from_server_handshake(transport: T) -> io::Result<Self> {
|
|
|
|
|
let mut transport = Self::plain(transport);
|
|
|
|
|
transport.server_handshake().await?;
|
|
|
|
|
Ok(transport)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Perform the server-side of a handshake. See [`handshake`] for more details.
|
|
|
|
|
///
|
|
|
|
@ -488,33 +498,7 @@ impl<T: Transport> FramedTransport<T> {
|
|
|
|
|
))
|
|
|
|
|
}
|
|
|
|
|
Some(ty) => {
|
|
|
|
|
#[derive(Serialize, Deserialize)]
|
|
|
|
|
struct KeyExchangeData {
|
|
|
|
|
/// Bytes of the public key
|
|
|
|
|
#[serde(with = "serde_bytes")]
|
|
|
|
|
public_key: PublicKeyBytes,
|
|
|
|
|
|
|
|
|
|
/// Randomly generated salt
|
|
|
|
|
#[serde(with = "serde_bytes")]
|
|
|
|
|
salt: Salt,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
debug!("[{log_label}] Exchanging public key and salt");
|
|
|
|
|
let exchange = KeyExchange::default();
|
|
|
|
|
write_frame!(KeyExchangeData {
|
|
|
|
|
public_key: exchange.pk_bytes(),
|
|
|
|
|
salt: *exchange.salt(),
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
// TODO: This key only works because it happens to be 32 bytes and our encryption
|
|
|
|
|
// also wants a 32-byte key. Once we introduce new encryption algorithms that
|
|
|
|
|
// are not using 32-byte keys, the key exchange will need to support deriving
|
|
|
|
|
// other length keys.
|
|
|
|
|
trace!("[{log_label}] Waiting on public key and salt from other side");
|
|
|
|
|
let data = next_frame_as!(KeyExchangeData);
|
|
|
|
|
|
|
|
|
|
trace!("[{log_label}] Deriving shared secret key");
|
|
|
|
|
let key = exchange.derive_shared_secret(data.public_key, data.salt)?;
|
|
|
|
|
let key = self.exchange_keys_impl(log_label).await?;
|
|
|
|
|
Some(ty.new_codec(key.unprotected_as_bytes())?)
|
|
|
|
|
}
|
|
|
|
|
None => None,
|
|
|
|
@ -539,6 +523,65 @@ impl<T: Transport> FramedTransport<T> {
|
|
|
|
|
|
|
|
|
|
Ok(codec)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Places the transport into key-exchange mode where it attempts to derive a shared secret key
|
|
|
|
|
/// with the other transport.
|
|
|
|
|
pub async fn exchange_keys(&mut self) -> io::Result<SecretKey32> {
|
|
|
|
|
self.exchange_keys_impl("").await
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn exchange_keys_impl(&mut self, label: &str) -> io::Result<SecretKey32> {
|
|
|
|
|
let log_label = if label.is_empty() {
|
|
|
|
|
String::new()
|
|
|
|
|
} else {
|
|
|
|
|
format!("[{label}] ")
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
macro_rules! write_frame {
|
|
|
|
|
($data:expr) => {{
|
|
|
|
|
self.write_frame(utils::serialize_to_vec(&$data)?).await?
|
|
|
|
|
}};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
macro_rules! next_frame_as {
|
|
|
|
|
($type:ty) => {{
|
|
|
|
|
let frame = self.read_frame().await?.ok_or_else(|| {
|
|
|
|
|
io::Error::new(io::ErrorKind::UnexpectedEof, "Transport closed early")
|
|
|
|
|
})?;
|
|
|
|
|
|
|
|
|
|
utils::deserialize_from_slice::<$type>(frame.as_item())?
|
|
|
|
|
}};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Serialize, Deserialize)]
|
|
|
|
|
struct KeyExchangeData {
|
|
|
|
|
/// Bytes of the public key
|
|
|
|
|
#[serde(with = "serde_bytes")]
|
|
|
|
|
public_key: PublicKeyBytes,
|
|
|
|
|
|
|
|
|
|
/// Randomly generated salt
|
|
|
|
|
#[serde(with = "serde_bytes")]
|
|
|
|
|
salt: Salt,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
debug!("{log_label}Exchanging public key and salt");
|
|
|
|
|
let exchange = KeyExchange::default();
|
|
|
|
|
write_frame!(KeyExchangeData {
|
|
|
|
|
public_key: exchange.pk_bytes(),
|
|
|
|
|
salt: *exchange.salt(),
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
// TODO: This key only works because it happens to be 32 bytes and our encryption
|
|
|
|
|
// also wants a 32-byte key. Once we introduce new encryption algorithms that
|
|
|
|
|
// are not using 32-byte keys, the key exchange will need to support deriving
|
|
|
|
|
// other length keys.
|
|
|
|
|
trace!("{log_label}Waiting on public key and salt from other side");
|
|
|
|
|
let data = next_frame_as!(KeyExchangeData);
|
|
|
|
|
|
|
|
|
|
trace!("{log_label}Deriving shared secret key");
|
|
|
|
|
let key = exchange.derive_shared_secret(data.public_key, data.salt)?;
|
|
|
|
|
Ok(key)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[async_trait]
|
|
|
|
@ -1361,4 +1404,46 @@ mod tests {
|
|
|
|
|
let err = t1.server_handshake().await.unwrap_err();
|
|
|
|
|
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test(tokio::test)]
|
|
|
|
|
async fn exchange_keys_should_fail_if_unable_to_send_exchange_data_to_other_side() {
|
|
|
|
|
let (mut t1, t2) = FramedTransport::test_pair(100);
|
|
|
|
|
|
|
|
|
|
// Drop the other side to ensure that the exchange fails at the beginning
|
|
|
|
|
drop(t2);
|
|
|
|
|
|
|
|
|
|
// Perform key exchange and verify error is as expected
|
|
|
|
|
assert_eq!(
|
|
|
|
|
t1.exchange_keys().await.unwrap_err().kind(),
|
|
|
|
|
io::ErrorKind::WriteZero
|
|
|
|
|
);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test(tokio::test)]
|
|
|
|
|
async fn exchange_keys_should_fail_if_received_invalid_exchange_data() {
|
|
|
|
|
let (mut t1, mut t2) = FramedTransport::test_pair(100);
|
|
|
|
|
|
|
|
|
|
// Queue up an invalid exchange response
|
|
|
|
|
t2.write_frame(b"some invalid frame").await.unwrap();
|
|
|
|
|
|
|
|
|
|
// Perform key exchange and verify error is as expected
|
|
|
|
|
assert_eq!(
|
|
|
|
|
t1.exchange_keys().await.unwrap_err().kind(),
|
|
|
|
|
io::ErrorKind::InvalidData
|
|
|
|
|
);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test(tokio::test)]
|
|
|
|
|
async fn exchange_keys_should_return_shared_secret_key_if_successful() {
|
|
|
|
|
let (mut t1, mut t2) = FramedTransport::test_pair(100);
|
|
|
|
|
|
|
|
|
|
// Spawn a task to avoid deadlocking
|
|
|
|
|
let task = tokio::spawn(async move { t2.exchange_keys().await.unwrap() });
|
|
|
|
|
|
|
|
|
|
// Perform key exchange
|
|
|
|
|
let key = t1.exchange_keys().await.unwrap();
|
|
|
|
|
|
|
|
|
|
// Validate that the keys on both sides match
|
|
|
|
|
assert_eq!(key, task.await.unwrap());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|