use std::io; use async_trait::async_trait; use distant_auth::msg::*; use distant_auth::{AuthHandler, Authenticate, Authenticator}; use log::*; use crate::common::{utils, FramedTransport, Transport}; macro_rules! write_frame { ($transport:expr, $data:expr) => {{ let data = utils::serialize_to_vec(&$data)?; if log_enabled!(Level::Trace) { trace!("Writing data as frame: {data:?}"); } $transport.write_frame(data).await? }}; } macro_rules! next_frame_as { ($transport:expr, $type:ident, $variant:ident) => {{ match { next_frame_as!($transport, $type) } { $type::$variant(x) => x, x => { return Err(io::Error::new( io::ErrorKind::InvalidData, format!("Unexpected frame: {x:?}"), )) } } }}; ($transport:expr, $type:ident) => {{ let frame = $transport.read_frame().await?.ok_or_else(|| { io::Error::new( io::ErrorKind::UnexpectedEof, concat!( "Transport closed early waiting for frame of type ", stringify!($type), ), ) })?; match utils::deserialize_from_slice::<$type>(frame.as_item()) { Ok(frame) => frame, Err(x) => { if log_enabled!(Level::Trace) { trace!( "Failed to deserialize frame item as {}: {:?}", stringify!($type), frame.as_item() ); } Err(x)?; unreachable!(); } } }}; } #[async_trait] impl Authenticate for FramedTransport where T: Transport, { async fn authenticate(&mut self, mut handler: impl AuthHandler + Send) -> io::Result<()> { loop { trace!("Authenticate::authenticate waiting on next authentication frame"); match next_frame_as!(self, Authentication) { Authentication::Initialization(x) => { trace!("Authenticate::Initialization({x:?})"); let response = handler.on_initialization(x).await?; write_frame!(self, AuthenticationResponse::Initialization(response)); } Authentication::Challenge(x) => { trace!("Authenticate::Challenge({x:?})"); let response = handler.on_challenge(x).await?; write_frame!(self, AuthenticationResponse::Challenge(response)); } Authentication::Verification(x) => { trace!("Authenticate::Verify({x:?})"); let response = handler.on_verification(x).await?; write_frame!(self, AuthenticationResponse::Verification(response)); } Authentication::Info(x) => { trace!("Authenticate::Info({x:?})"); handler.on_info(x).await?; } Authentication::Error(x) => { trace!("Authenticate::Error({x:?})"); handler.on_error(x.clone()).await?; if x.is_fatal() { return Err(x.into_io_permission_denied()); } } Authentication::StartMethod(x) => { trace!("Authenticate::StartMethod({x:?})"); handler.on_start_method(x).await?; } Authentication::Finished => { trace!("Authenticate::Finished"); handler.on_finished().await?; return Ok(()); } } } } } #[async_trait] impl Authenticator for FramedTransport where T: Transport, { async fn initialize( &mut self, initialization: Initialization, ) -> io::Result { trace!("Authenticator::initialize({initialization:?})"); write_frame!(self, Authentication::Initialization(initialization)); let response = next_frame_as!(self, AuthenticationResponse, Initialization); Ok(response) } async fn challenge(&mut self, challenge: Challenge) -> io::Result { trace!("Authenticator::challenge({challenge:?})"); write_frame!(self, Authentication::Challenge(challenge)); let response = next_frame_as!(self, AuthenticationResponse, Challenge); Ok(response) } async fn verify(&mut self, verification: Verification) -> io::Result { trace!("Authenticator::verify({verification:?})"); write_frame!(self, Authentication::Verification(verification)); let response = next_frame_as!(self, AuthenticationResponse, Verification); Ok(response) } async fn info(&mut self, info: Info) -> io::Result<()> { trace!("Authenticator::info({info:?})"); write_frame!(self, Authentication::Info(info)); Ok(()) } async fn error(&mut self, error: Error) -> io::Result<()> { trace!("Authenticator::error({error:?})"); write_frame!(self, Authentication::Error(error)); Ok(()) } async fn start_method(&mut self, start_method: StartMethod) -> io::Result<()> { trace!("Authenticator::start_method({start_method:?})"); write_frame!(self, Authentication::StartMethod(start_method)); Ok(()) } async fn finished(&mut self) -> io::Result<()> { trace!("Authenticator::finished()"); write_frame!(self, Authentication::Finished); Ok(()) } } #[cfg(test)] mod tests { use distant_auth::tests::TestAuthHandler; use test_log::test; use tokio::sync::mpsc; use super::*; #[test(tokio::test)] async fn authenticator_initialization_should_be_able_to_successfully_complete_round_trip() { let (mut t1, mut t2) = FramedTransport::test_pair(100); let task = tokio::spawn(async move { t2.authenticate(TestAuthHandler { on_initialization: Box::new(|x| Ok(InitializationResponse { methods: x.methods })), ..Default::default() }) .await .unwrap() }); let response = t1 .initialize(Initialization { methods: vec!["test method".to_string()].into_iter().collect(), }) .await .unwrap(); assert!( !task.is_finished(), "Auth handler unexpectedly finished without signal" ); assert_eq!( response, InitializationResponse { methods: vec!["test method".to_string()].into_iter().collect() } ); } #[test(tokio::test)] async fn authenticator_challenge_should_be_able_to_successfully_complete_round_trip() { let (mut t1, mut t2) = FramedTransport::test_pair(100); let task = tokio::spawn(async move { t2.authenticate(TestAuthHandler { on_challenge: Box::new(|challenge| { assert_eq!( challenge.questions, vec![Question { label: "label".to_string(), text: "text".to_string(), options: vec![( "question_key".to_string(), "question_value".to_string() )] .into_iter() .collect(), }] ); assert_eq!( challenge.options, vec![("key".to_string(), "value".to_string())] .into_iter() .collect(), ); Ok(ChallengeResponse { answers: vec!["some answer".to_string()].into_iter().collect(), }) }), ..Default::default() }) .await .unwrap() }); let response = t1 .challenge(Challenge { questions: vec![Question { label: "label".to_string(), text: "text".to_string(), options: vec![("question_key".to_string(), "question_value".to_string())] .into_iter() .collect(), }], options: vec![("key".to_string(), "value".to_string())] .into_iter() .collect(), }) .await .unwrap(); assert!( !task.is_finished(), "Auth handler unexpectedly finished without signal" ); assert_eq!( response, ChallengeResponse { answers: vec!["some answer".to_string()], } ); } #[test(tokio::test)] async fn authenticator_verification_should_be_able_to_successfully_complete_round_trip() { let (mut t1, mut t2) = FramedTransport::test_pair(100); let task = tokio::spawn(async move { t2.authenticate(TestAuthHandler { on_verification: Box::new(|verification| { assert_eq!(verification.kind, VerificationKind::Host); assert_eq!(verification.text, "some text"); Ok(VerificationResponse { valid: true }) }), ..Default::default() }) .await .unwrap() }); let response = t1 .verify(Verification { kind: VerificationKind::Host, text: "some text".to_string(), }) .await .unwrap(); assert!( !task.is_finished(), "Auth handler unexpectedly finished without signal" ); assert_eq!(response, VerificationResponse { valid: true }); } #[test(tokio::test)] async fn authenticator_info_should_be_able_to_be_sent_to_auth_handler() { let (mut t1, mut t2) = FramedTransport::test_pair(100); let (tx, mut rx) = mpsc::channel(1); let task = tokio::spawn(async move { t2.authenticate(TestAuthHandler { on_info: Box::new(move |info| { tx.try_send(info).unwrap(); Ok(()) }), ..Default::default() }) .await .unwrap() }); t1.info(Info { text: "some text".to_string(), }) .await .unwrap(); assert_eq!( rx.recv().await.unwrap(), Info { text: "some text".to_string() } ); assert!( !task.is_finished(), "Auth handler unexpectedly finished without signal" ); } #[test(tokio::test)] async fn authenticator_error_should_be_able_to_be_sent_to_auth_handler() { let (mut t1, mut t2) = FramedTransport::test_pair(100); let (tx, mut rx) = mpsc::channel(1); let task = tokio::spawn(async move { t2.authenticate(TestAuthHandler { on_error: Box::new(move |error| { tx.try_send(error).unwrap(); Ok(()) }), ..Default::default() }) .await .unwrap() }); t1.error(Error { kind: ErrorKind::Error, text: "some text".to_string(), }) .await .unwrap(); assert_eq!( rx.recv().await.unwrap(), Error { kind: ErrorKind::Error, text: "some text".to_string(), } ); assert!( !task.is_finished(), "Auth handler unexpectedly finished without signal" ); } #[test(tokio::test)] async fn auth_handler_received_error_should_fail_auth_handler_if_fatal() { let (mut t1, mut t2) = FramedTransport::test_pair(100); let (tx, mut rx) = mpsc::channel(1); let task = tokio::spawn(async move { t2.authenticate(TestAuthHandler { on_error: Box::new(move |error| { tx.try_send(error).unwrap(); Ok(()) }), ..Default::default() }) .await .unwrap() }); t1.error(Error { kind: ErrorKind::Fatal, text: "some text".to_string(), }) .await .unwrap(); assert_eq!( rx.recv().await.unwrap(), Error { kind: ErrorKind::Fatal, text: "some text".to_string(), } ); // Verify that the handler exited with an error task.await.unwrap_err(); } #[test(tokio::test)] async fn authenticator_start_method_should_be_able_to_be_sent_to_auth_handler() { let (mut t1, mut t2) = FramedTransport::test_pair(100); let (tx, mut rx) = mpsc::channel(1); let task = tokio::spawn(async move { t2.authenticate(TestAuthHandler { on_start_method: Box::new(move |start_method| { tx.try_send(start_method).unwrap(); Ok(()) }), ..Default::default() }) .await .unwrap() }); t1.start_method(StartMethod { method: "some method".to_string(), }) .await .unwrap(); assert_eq!( rx.recv().await.unwrap(), StartMethod { method: "some method".to_string() } ); assert!( !task.is_finished(), "Auth handler unexpectedly finished without signal" ); } #[test(tokio::test)] async fn authenticator_finished_should_be_able_to_be_sent_to_auth_handler() { let (mut t1, mut t2) = FramedTransport::test_pair(100); let (tx, mut rx) = mpsc::channel(1); let task = tokio::spawn(async move { t2.authenticate(TestAuthHandler { on_finished: Box::new(move || { tx.try_send(()).unwrap(); Ok(()) }), ..Default::default() }) .await .unwrap() }); t1.finished().await.unwrap(); // Verify that the callback was triggered rx.recv().await.unwrap(); // Finished should signal that the handler completed successfully task.await.unwrap(); } }