You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
distant/distant-net/src/authentication.rs

477 lines
15 KiB
Rust

use std::io;
use async_trait::async_trait;
use distant_auth::msg::*;
use distant_auth::{AuthHandler, Authenticate, Authenticator};
use log::*;
use crate::common::{utils, FramedTransport, Transport};
/// Serializes `$data` and writes the resulting bytes to `$transport` as a single frame.
///
/// Expands to code that uses `?` and `.await`, so it may only be used inside an
/// `async` context whose error type can absorb the serializer's and transport's errors.
/// When trace logging is enabled, the serialized bytes are logged before being written.
macro_rules! write_frame {
    ($transport:expr, $data:expr) => {{
        let bytes = utils::serialize_to_vec(&$data)?;

        // Only format the (potentially large) byte payload when tracing is on.
        if log_enabled!(Level::Trace) {
            trace!("Writing data as frame: {:?}", bytes);
        }

        $transport.write_frame(bytes).await?
    }};
}
/// Reads the next frame from `$transport` and deserializes its item.
///
/// Two forms are supported:
/// * `next_frame_as!(transport, Type)` evaluates to the frame item deserialized as `Type`.
/// * `next_frame_as!(transport, Type, Variant)` additionally requires the deserialized
///   value to be `Type::Variant(inner)` and evaluates to `inner`.
///
/// Must be used inside an `async` context returning `io::Result<_>`. The enclosing
/// function returns early with:
/// * an `io::ErrorKind::UnexpectedEof` error if the transport yields no frame,
/// * the (converted) deserialization error if the item cannot be decoded as `Type`,
/// * an `io::ErrorKind::InvalidData` error if the decoded value is not `Type::Variant`.
macro_rules! next_frame_as {
    ($transport:expr, $type:ident, $variant:ident) => {{
        match next_frame_as!($transport, $type) {
            $type::$variant(x) => x,
            x => {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("Unexpected frame: {x:?}"),
                ))
            }
        }
    }};
    ($transport:expr, $type:ident) => {{
        let frame = $transport.read_frame().await?.ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::UnexpectedEof,
                concat!(
                    "Transport closed early waiting for frame of type ",
                    stringify!($type),
                ),
            )
        })?;

        // Log the raw item on failure, then propagate the error with `?` (which applies
        // the same `From` conversion the caller's error type expects).
        utils::deserialize_from_slice::<$type>(frame.as_item()).map_err(|x| {
            if log_enabled!(Level::Trace) {
                trace!(
                    "Failed to deserialize frame item as {}: {:?}",
                    stringify!($type),
                    frame.as_item()
                );
            }
            x
        })?
    }};
}
#[async_trait]
impl<T> Authenticate for FramedTransport<T>
where
    T: Transport,
{
    /// Runs the handler side of the authentication exchange over this transport.
    ///
    /// Reads `Authentication` frames one at a time and dispatches each to the matching
    /// `handler` callback. For initialization, challenge, and verification frames, the
    /// callback's result is written back as the corresponding `AuthenticationResponse`
    /// frame. Info, error, and start-method frames only notify the handler.
    ///
    /// Returns `Ok(())` after a `Finished` frame has been passed to `on_finished`.
    ///
    /// # Errors
    ///
    /// Fails if reading, decoding, or writing a frame fails, or if any handler callback
    /// fails. If a received error reports `is_fatal()`, the handler is still notified and
    /// the error converted by `into_io_permission_denied` is returned.
    async fn authenticate(&mut self, mut handler: impl AuthHandler + Send) -> io::Result<()> {
        loop {
            trace!("Authenticate::authenticate waiting on next authentication frame");
            let frame = next_frame_as!(self, Authentication);

            match frame {
                Authentication::Initialization(x) => {
                    trace!("Authenticate::Initialization({x:?})");
                    let reply = handler.on_initialization(x).await?;
                    write_frame!(self, AuthenticationResponse::Initialization(reply));
                }
                Authentication::Challenge(x) => {
                    trace!("Authenticate::Challenge({x:?})");
                    let reply = handler.on_challenge(x).await?;
                    write_frame!(self, AuthenticationResponse::Challenge(reply));
                }
                Authentication::Verification(x) => {
                    trace!("Authenticate::Verify({x:?})");
                    let reply = handler.on_verification(x).await?;
                    write_frame!(self, AuthenticationResponse::Verification(reply));
                }
                Authentication::Info(x) => {
                    trace!("Authenticate::Info({x:?})");
                    handler.on_info(x).await?;
                }
                Authentication::Error(x) => {
                    trace!("Authenticate::Error({x:?})");
                    handler.on_error(x.clone()).await?;

                    // Non-fatal errors are only reported; keep waiting for more frames.
                    if x.is_fatal() {
                        return Err(x.into_io_permission_denied());
                    }
                }
                Authentication::StartMethod(x) => {
                    trace!("Authenticate::StartMethod({x:?})");
                    handler.on_start_method(x).await?;
                }
                Authentication::Finished => {
                    trace!("Authenticate::Finished");
                    break;
                }
            }
        }

        // Only reachable via the `Finished` frame above.
        handler.on_finished().await?;
        Ok(())
    }
}
#[async_trait]
impl<T> Authenticator for FramedTransport<T>
where
    T: Transport,
{
    /// Sends an `Initialization` frame and waits for the peer's
    /// `AuthenticationResponse::Initialization` reply.
    ///
    /// Any other reply variant yields an `io::ErrorKind::InvalidData` error.
    async fn initialize(
        &mut self,
        initialization: Initialization,
    ) -> io::Result<InitializationResponse> {
        trace!("Authenticator::initialize({initialization:?})");
        write_frame!(self, Authentication::Initialization(initialization));
        Ok(next_frame_as!(self, AuthenticationResponse, Initialization))
    }

    /// Sends a `Challenge` frame and waits for the peer's
    /// `AuthenticationResponse::Challenge` reply.
    ///
    /// Any other reply variant yields an `io::ErrorKind::InvalidData` error.
    async fn challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
        trace!("Authenticator::challenge({challenge:?})");
        write_frame!(self, Authentication::Challenge(challenge));
        Ok(next_frame_as!(self, AuthenticationResponse, Challenge))
    }

    /// Sends a `Verification` frame and waits for the peer's
    /// `AuthenticationResponse::Verification` reply.
    ///
    /// Any other reply variant yields an `io::ErrorKind::InvalidData` error.
    async fn verify(&mut self, verification: Verification) -> io::Result<VerificationResponse> {
        trace!("Authenticator::verify({verification:?})");
        write_frame!(self, Authentication::Verification(verification));
        Ok(next_frame_as!(self, AuthenticationResponse, Verification))
    }

    /// Sends an `Info` frame; no reply is expected.
    async fn info(&mut self, info: Info) -> io::Result<()> {
        trace!("Authenticator::info({info:?})");
        write_frame!(self, Authentication::Info(info));
        Ok(())
    }

    /// Sends an `Error` frame; no reply is expected.
    async fn error(&mut self, error: Error) -> io::Result<()> {
        trace!("Authenticator::error({error:?})");
        write_frame!(self, Authentication::Error(error));
        Ok(())
    }

    /// Sends a `StartMethod` frame; no reply is expected.
    async fn start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
        trace!("Authenticator::start_method({start_method:?})");
        write_frame!(self, Authentication::StartMethod(start_method));
        Ok(())
    }

    /// Sends a `Finished` frame; no reply is expected.
    async fn finished(&mut self) -> io::Result<()> {
        trace!("Authenticator::finished()");
        write_frame!(self, Authentication::Finished);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use distant_auth::tests::TestAuthHandler;
    use test_log::test;
    use tokio::sync::mpsc;

    use super::*;

    // Each test pairs an `Authenticator` side (t1) with an `Authenticate` handler loop
    // (t2) over an in-memory transport pair and checks what crosses the wire.

    #[test(tokio::test)]
    async fn authenticator_initialization_should_be_able_to_successfully_complete_round_trip() {
        let (mut t1, mut t2) = FramedTransport::test_pair(100);

        let task = tokio::spawn(async move {
            t2.authenticate(TestAuthHandler {
                on_initialization: Box::new(|x| Ok(InitializationResponse { methods: x.methods })),
                ..Default::default()
            })
            .await
            .unwrap()
        });

        let response = t1
            .initialize(Initialization {
                methods: vec!["test method".to_string()].into_iter().collect(),
            })
            .await
            .unwrap();

        assert!(
            !task.is_finished(),
            "Auth handler unexpectedly finished without signal"
        );

        assert_eq!(
            response,
            InitializationResponse {
                methods: vec!["test method".to_string()].into_iter().collect()
            }
        );
    }

    #[test(tokio::test)]
    async fn authenticator_challenge_should_be_able_to_successfully_complete_round_trip() {
        let (mut t1, mut t2) = FramedTransport::test_pair(100);

        let task = tokio::spawn(async move {
            t2.authenticate(TestAuthHandler {
                on_challenge: Box::new(|challenge| {
                    assert_eq!(
                        challenge.questions,
                        vec![Question {
                            label: "label".to_string(),
                            text: "text".to_string(),
                            options: vec![(
                                "question_key".to_string(),
                                "question_value".to_string()
                            )]
                            .into_iter()
                            .collect(),
                        }]
                    );
                    assert_eq!(
                        challenge.options,
                        vec![("key".to_string(), "value".to_string())]
                            .into_iter()
                            .collect(),
                    );
                    Ok(ChallengeResponse {
                        answers: vec!["some answer".to_string()].into_iter().collect(),
                    })
                }),
                ..Default::default()
            })
            .await
            .unwrap()
        });

        let response = t1
            .challenge(Challenge {
                questions: vec![Question {
                    label: "label".to_string(),
                    text: "text".to_string(),
                    options: vec![("question_key".to_string(), "question_value".to_string())]
                        .into_iter()
                        .collect(),
                }],
                options: vec![("key".to_string(), "value".to_string())]
                    .into_iter()
                    .collect(),
            })
            .await
            .unwrap();

        assert!(
            !task.is_finished(),
            "Auth handler unexpectedly finished without signal"
        );

        assert_eq!(
            response,
            ChallengeResponse {
                answers: vec!["some answer".to_string()],
            }
        );
    }

    #[test(tokio::test)]
    async fn authenticator_verification_should_be_able_to_successfully_complete_round_trip() {
        let (mut t1, mut t2) = FramedTransport::test_pair(100);

        let task = tokio::spawn(async move {
            t2.authenticate(TestAuthHandler {
                on_verification: Box::new(|verification| {
                    assert_eq!(verification.kind, VerificationKind::Host);
                    assert_eq!(verification.text, "some text");
                    Ok(VerificationResponse { valid: true })
                }),
                ..Default::default()
            })
            .await
            .unwrap()
        });

        let response = t1
            .verify(Verification {
                kind: VerificationKind::Host,
                text: "some text".to_string(),
            })
            .await
            .unwrap();

        assert!(
            !task.is_finished(),
            "Auth handler unexpectedly finished without signal"
        );

        assert_eq!(response, VerificationResponse { valid: true });
    }

    #[test(tokio::test)]
    async fn authenticator_info_should_be_able_to_be_sent_to_auth_handler() {
        let (mut t1, mut t2) = FramedTransport::test_pair(100);
        let (tx, mut rx) = mpsc::channel(1);

        let task = tokio::spawn(async move {
            t2.authenticate(TestAuthHandler {
                on_info: Box::new(move |info| {
                    tx.try_send(info).unwrap();
                    Ok(())
                }),
                ..Default::default()
            })
            .await
            .unwrap()
        });

        t1.info(Info {
            text: "some text".to_string(),
        })
        .await
        .unwrap();

        assert_eq!(
            rx.recv().await.unwrap(),
            Info {
                text: "some text".to_string()
            }
        );

        assert!(
            !task.is_finished(),
            "Auth handler unexpectedly finished without signal"
        );
    }

    #[test(tokio::test)]
    async fn authenticator_error_should_be_able_to_be_sent_to_auth_handler() {
        let (mut t1, mut t2) = FramedTransport::test_pair(100);
        let (tx, mut rx) = mpsc::channel(1);

        let task = tokio::spawn(async move {
            t2.authenticate(TestAuthHandler {
                on_error: Box::new(move |error| {
                    tx.try_send(error).unwrap();
                    Ok(())
                }),
                ..Default::default()
            })
            .await
            .unwrap()
        });

        t1.error(Error {
            kind: ErrorKind::Error,
            text: "some text".to_string(),
        })
        .await
        .unwrap();

        assert_eq!(
            rx.recv().await.unwrap(),
            Error {
                kind: ErrorKind::Error,
                text: "some text".to_string(),
            }
        );

        assert!(
            !task.is_finished(),
            "Auth handler unexpectedly finished without signal"
        );
    }

    #[test(tokio::test)]
    async fn auth_handler_received_error_should_fail_auth_handler_if_fatal() {
        let (mut t1, mut t2) = FramedTransport::test_pair(100);
        let (tx, mut rx) = mpsc::channel(1);

        // Return the raw result instead of unwrapping inside the task. A panic (for
        // example in the handler callback) then cannot be mistaken for the expected
        // authentication failure.
        let task = tokio::spawn(async move {
            t2.authenticate(TestAuthHandler {
                on_error: Box::new(move |error| {
                    tx.try_send(error).unwrap();
                    Ok(())
                }),
                ..Default::default()
            })
            .await
        });

        t1.error(Error {
            kind: ErrorKind::Fatal,
            text: "some text".to_string(),
        })
        .await
        .unwrap();

        assert_eq!(
            rx.recv().await.unwrap(),
            Error {
                kind: ErrorKind::Fatal,
                text: "some text".to_string(),
            }
        );

        // Verify that the handler exited with an error (and did not panic)
        let result = task.await.expect("Auth handler task panicked");
        result.expect_err("Auth handler should fail after receiving a fatal error");
    }

    #[test(tokio::test)]
    async fn authenticator_start_method_should_be_able_to_be_sent_to_auth_handler() {
        let (mut t1, mut t2) = FramedTransport::test_pair(100);
        let (tx, mut rx) = mpsc::channel(1);

        let task = tokio::spawn(async move {
            t2.authenticate(TestAuthHandler {
                on_start_method: Box::new(move |start_method| {
                    tx.try_send(start_method).unwrap();
                    Ok(())
                }),
                ..Default::default()
            })
            .await
            .unwrap()
        });

        t1.start_method(StartMethod {
            method: "some method".to_string(),
        })
        .await
        .unwrap();

        assert_eq!(
            rx.recv().await.unwrap(),
            StartMethod {
                method: "some method".to_string()
            }
        );

        assert!(
            !task.is_finished(),
            "Auth handler unexpectedly finished without signal"
        );
    }

    #[test(tokio::test)]
    async fn authenticator_finished_should_be_able_to_be_sent_to_auth_handler() {
        let (mut t1, mut t2) = FramedTransport::test_pair(100);
        let (tx, mut rx) = mpsc::channel(1);

        let task = tokio::spawn(async move {
            t2.authenticate(TestAuthHandler {
                on_finished: Box::new(move || {
                    tx.try_send(()).unwrap();
                    Ok(())
                }),
                ..Default::default()
            })
            .await
            .unwrap()
        });

        t1.finished().await.unwrap();

        // Verify that the callback was triggered
        rx.recv().await.unwrap();

        // Finished should signal that the handler completed successfully
        task.await.unwrap();
    }

    #[test(tokio::test)]
    async fn authenticator_should_fail_if_response_is_of_unexpected_variant() {
        let (mut t1, mut t2) = FramedTransport::test_pair(100);

        let task = tokio::spawn(async move {
            // Consume the initialization request, then reply with a mismatched variant
            t2.read_frame().await.unwrap().unwrap();
            let data = utils::serialize_to_vec(&AuthenticationResponse::Verification(
                VerificationResponse { valid: true },
            ))
            .unwrap();
            t2.write_frame(data).await.unwrap();

            // Hand the transport back so it stays alive until the test is done with it
            t2
        });

        let err = t1
            .initialize(Initialization {
                methods: vec!["test method".to_string()].into_iter().collect(),
            })
            .await
            .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);

        task.await.unwrap();
    }
}