pull/146/head
Chip Senkbeil 2 years ago
parent d375298d5b
commit 5b12d3d7fe
No known key found for this signature in database
GPG Key ID: 35EF1F8EC72A4131

@ -3,14 +3,58 @@ use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Represents the result of a request to the database.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum KeychainResult<T> {
/// Id was not found in the database.
InvalidId,
/// Password match for an id failed.
InvalidPassword,
/// Successful match of id and password, removing from keychain and returning data `T`.
Ok(T),
}
impl<T> KeychainResult<T> {
pub fn is_invalid_id(&self) -> bool {
matches!(self, Self::InvalidId)
}
pub fn is_invalid_password(&self) -> bool {
matches!(self, Self::InvalidPassword)
}
pub fn is_invalid(&self) -> bool {
matches!(self, Self::InvalidId | Self::InvalidPassword)
}
pub fn is_ok(&self) -> bool {
matches!(self, Self::Ok(_))
}
pub fn into_ok(self) -> Option<T> {
match self {
Self::Ok(x) => Some(x),
_ => None,
}
}
}
impl<T> From<KeychainResult<T>> for Option<T> {
fn from(result: KeychainResult<T>) -> Self {
result.into_ok()
}
}
/// Manages keys with associated ids. Cloning will result in a copy pointing to the same underlying
/// storage, which enables support of managing the keys across multiple threads.
#[derive(Clone, Debug)]
pub struct Keychain {
map: Arc<RwLock<HashMap<String, HeapSecretKey>>>,
pub struct Keychain<T = ()> {
map: Arc<RwLock<HashMap<String, (HeapSecretKey, T)>>>,
}
impl Keychain {
impl<T> Keychain<T> {
/// Creates a new keychain without any keys.
pub fn new() -> Self {
Self {
@ -18,10 +62,14 @@ impl Keychain {
}
}
/// Stores a new `key` by a given `id`, returning the old key if there was one already
/// registered.
pub async fn insert(&self, id: impl Into<String>, key: HeapSecretKey) -> Option<HeapSecretKey> {
self.map.write().await.insert(id.into(), key)
/// Stores a new `key` and `data` by a given `id`, returning the old data associated with the
/// id if there was one already registered.
pub async fn insert(&self, id: impl Into<String>, key: HeapSecretKey, data: T) -> Option<T> {
self.map
.write()
.await
.insert(id.into(), (key, data))
.map(|(_, data)| data)
}
/// Checks if there is a key with the given `id` that matches the provided `key`.
@ -30,13 +78,44 @@ impl Keychain {
.read()
.await
.get(id.as_ref())
.map(|k| key.eq(k))
.map(|(k, _)| key.eq(k))
.unwrap_or(false)
}
/// Removes a key by a given `id`, returning the key if there was one found for the given id.
pub async fn remove(&self, id: impl AsRef<str>) -> Option<HeapSecretKey> {
self.map.write().await.remove(id.as_ref())
/// Removes a key and its data by a given `id`, returning the data if the `id` exists.
pub async fn remove(&self, id: impl AsRef<str>) -> Option<T> {
self.map
.write()
.await
.remove(id.as_ref())
.map(|(_, data)| data)
}
/// Checks if there is a key with the given `id` that matches the provided `key`, returning the
/// data if the `id` exists and the `key` matches.
pub async fn remove_if_has_key(
&self,
id: impl AsRef<str>,
key: impl PartialEq<HeapSecretKey>,
) -> KeychainResult<T> {
let id = id.as_ref();
let mut lock = self.map.write().await;
match lock.get(id) {
Some((k, _)) if key.eq(k) => {
drop(k);
KeychainResult::Ok(lock.remove(id).unwrap().1)
}
Some(_) => KeychainResult::InvalidPassword,
None => KeychainResult::InvalidId,
}
}
}
impl Keychain<()> {
/// Stores a new `key by a given `id`.
pub async fn put(&self, id: impl Into<String>, key: HeapSecretKey) {
self.insert(id, key, ()).await;
}
}
@ -46,11 +125,22 @@ impl Default for Keychain {
}
}
impl From<HashMap<String, HeapSecretKey>> for Keychain {
impl<T> From<HashMap<String, (HeapSecretKey, T)>> for Keychain<T> {
/// Creates a new keychain populated with the provided `map`.
fn from(map: HashMap<String, HeapSecretKey>) -> Self {
fn from(map: HashMap<String, (HeapSecretKey, T)>) -> Self {
Self {
map: Arc::new(RwLock::new(map)),
}
}
}
impl From<HashMap<String, HeapSecretKey>> for Keychain<()> {
/// Creates a new keychain populated with the provided `map`.
fn from(map: HashMap<String, HeapSecretKey>) -> Self {
Self::from(
map.into_iter()
.map(|(id, key)| (id, (key, ())))
.collect::<HashMap<String, (HeapSecretKey, ())>>(),
)
}
}

@ -1,6 +1,6 @@
use super::{
authentication::{AuthHandler, Authenticate, Keychain, Verifier},
FramedTransport, HeapSecretKey, Reconnectable, Transport,
authentication::{AuthHandler, Authenticate, Keychain, KeychainResult, Verifier},
Backup, FramedTransport, HeapSecretKey, Reconnectable, Transport,
};
use async_trait::async_trait;
use log::*;
@ -108,14 +108,12 @@ where
io::Error::new(io::ErrorKind::Other, "Missing connection id frame")
})?;
debug!("[Conn {id}] Resetting id to {new_id}");
*id = new_id;
// Derive an OTP for reauthentication
debug!("[Conn {new_id}] Deriving future OTP for reauthentication");
let new_reauth_otp = transport.exchange_keys().await?.into_heap_secret_key();
debug!("[Conn {id}] Deriving future OTP for reauthentication");
*reauth_otp = transport.exchange_keys().await?.into_heap_secret_key();
// Update our connection's id and reauth OTP
*id = new_id;
*reauth_otp = new_reauth_otp;
Ok(())
}
@ -135,6 +133,7 @@ where
result?;
// Perform synchronization
debug!("[Conn {id}] Synchronizing frame state");
transport.synchronize().await?;
Ok(())
@ -236,7 +235,11 @@ where
/// from our database
/// 3. Restores pre-existing state using the provided backup, replaying any missing frames and
/// receiving any frames from the other side
pub async fn server(transport: T, verifier: &Verifier, keychain: Keychain) -> io::Result<Self> {
pub async fn server(
transport: T,
verifier: &Verifier,
keychain: Keychain<Backup>,
) -> io::Result<Self> {
let id: ConnectionId = rand::random();
// Perform a handshake to ensure that the connection is properly established and encrypted
@ -271,34 +274,50 @@ where
// Perform authentication to ensure the connection is valid
debug!("[Conn {id}] Verifying connection");
verifier.verify(&mut transport).await?;
// Derive an OTP for reauthentication
debug!("[Conn {id}] Deriving future OTP for reauthentication");
let reauth_otp = transport.exchange_keys().await?.into_heap_secret_key();
}
ConnectType::Reconnect { id: other_id, otp } => {
let reauth_otp = HeapSecretKey::from(otp);
debug!("[Conn {id}] Checking if {other_id} exists");
if let Some(otp) = keychain.remove(other_id.to_string()).await {
debug!("[Conn {id}] Checking if OTP matches for {other_id}");
if reauth_otp != otp {
// Re-add the existing OTP since we didn't match
keychain.insert(other_id.to_string(), otp).await;
debug!("[Conn {id}] Checking if {other_id} exists and has matching OTP");
match keychain
.remove_if_has_key(other_id.to_string(), reauth_otp)
.await
{
KeychainResult::Ok(backup) => {
// Communicate the connection id
debug!("[Conn {id}] Telling other side to change connection id");
transport.write_frame_for(&id).await?;
// Derive an OTP for reauthentication
debug!("[Conn {id}] Deriving future OTP for reauthentication");
let reauth_otp = transport.exchange_keys().await?.into_heap_secret_key();
// Synchronize using the provided backup
debug!("[Conn {id}] Synchronizing frame state");
transport.backup = backup;
transport.synchronize().await?;
}
KeychainResult::InvalidPassword => {
return Err(io::Error::new(
io::ErrorKind::PermissionDenied,
"Invalid OTP for reconnect",
));
}
KeychainResult::InvalidId => {
return Err(io::Error::new(
io::ErrorKind::PermissionDenied,
"Invalid OTP",
"Invalid id for reconnect",
));
}
}
// Communicate the connection id
debug!("[Conn {id}] Telling other side to change connection id");
transport.write_frame_for(&id).await?;
}
}
// Derive an OTP for reauthentication
debug!("[Conn {id}] Deriving future OTP for reauthentication");
let reauth_otp = transport.exchange_keys().await?.into_heap_secret_key();
// Store the id and OTP in our database
keychain.insert(id.to_string(), reauth_otp).await;

Loading…
Cancel
Save